# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Module for constructing RNN Cells.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import math from tensorflow.contrib.compiler import jit from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.layers import base as base_layer from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl # pylint: disable=unused-import from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables # pylint: disable=unused-import from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest def _get_concat_variable(name, shape, dtype, num_shards): """Get a sharded variable concatenated into one tensor.""" sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards) if len(sharded_variable) == 1: return sharded_variable[0] concat_name = name + "/concat" concat_full_name = vs.get_variable_scope().name + "/" + concat_name + ":0" for value in ops.get_collection(ops.GraphKeys.CONCATENATED_VARIABLES): if value.name == concat_full_name: return value concat_variable = array_ops.concat(sharded_variable, 0, name=concat_name) ops.add_to_collection(ops.GraphKeys.CONCATENATED_VARIABLES, concat_variable) return concat_variable def _get_sharded_variable(name, shape, dtype, num_shards): """Get a list of sharded variables with the given dtype.""" if num_shards > shape[0]: raise ValueError("Too many shards: shape=%s, num_shards=%d" % (shape, num_shards)) unit_shard_size = int(math.floor(shape[0] / num_shards)) remaining_rows = shape[0] - unit_shard_size * num_shards shards = [] for i in range(num_shards): current_size = unit_shard_size if i < remaining_rows: current_size += 1 shards.append( vs.get_variable( name + "_%d" % i, [current_size] + shape[1:], dtype=dtype)) return shards def _norm(g, b, inp, scope): shape = inp.get_shape()[-1:] gamma_init = init_ops.constant_initializer(g) beta_init = init_ops.constant_initializer(b) with vs.variable_scope(scope): # Initialize beta and gamma for use by layer_norm. 
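    # Creating "gamma" and "beta" here, with constant initializers built from
    # the caller-supplied g and b, is what makes the layer_norm call below use
    # those initial values: layer_norm is invoked with reuse=True, so it picks
    # up the variables created in this scope rather than creating its own.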
vs.get_variable("gamma", shape=shape, initializer=gamma_init) vs.get_variable("beta", shape=shape, initializer=beta_init) normalized = layers.layer_norm(inp, reuse=True, scope=scope) return normalized class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. The default non-peephole implementation is based on: http://www.bioinf.jku.at/publications/older/2604.pdf S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. The peephole implementation is based on: https://research.google.com/pubs/archive/43905.pdf Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. The coupling of input and forget gate is based on: http://arxiv.org/pdf/1503.04069.pdf Greff et al. "LSTM: A Search Space Odyssey" The class uses optional peep-hole connections, and an optional projection layer. Layer normalization implementation is based on: https://arxiv.org/abs/1607.06450. "Layer Normalization" Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton and is applied before the internal nonlinearities. """ def __init__(self, num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=True, activation=math_ops.tanh, reuse=None, layer_norm=False, norm_gain=1.0, norm_shift=0.0): """Initialize the parameters for an LSTM cell. Args: num_units: int, The number of units in the LSTM cell use_peepholes: bool, set True to enable diagonal/peephole connections. initializer: (optional) The initializer to use for the weight and projection matrices. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is provided, then the projected values are clipped elementwise to within `[-proj_clip, proj_clip]`. num_unit_shards: How to split the weight matrix. If >1, the weight matrix is stored across num_unit_shards. num_proj_shards: How to split the projection matrix. If >1, the projection matrix is stored across num_proj_shards. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. By default (False), they are concatenated along the column axis. This default behavior will soon be deprecated. activation: Activation function of the inner states. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. layer_norm: If `True`, layer normalization will be applied. norm_gain: float, The layer normalization gain initial value. If `layer_norm` has been set to `False`, this argument will be ignored. norm_shift: float, The layer normalization shift initial value. If `layer_norm` has been set to `False`, this argument will be ignored. """ super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. 
Use state_is_tuple=True.", self) self._num_units = num_units self._use_peepholes = use_peepholes self._initializer = initializer self._num_proj = num_proj self._proj_clip = proj_clip self._num_unit_shards = num_unit_shards self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple self._activation = activation self._reuse = reuse self._layer_norm = layer_norm self._norm_gain = norm_gain self._norm_shift = norm_shift if num_proj: self._state_size = ( rnn_cell_impl.LSTMStateTuple(num_units, num_proj) if state_is_tuple else num_units + num_proj) self._output_size = num_proj else: self._state_size = ( rnn_cell_impl.LSTMStateTuple(num_units, num_units) if state_is_tuple else 2 * num_units) self._output_size = num_units @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, batch x num_units. state: if `state_is_tuple` is False, this must be a state Tensor, `2-D, batch x state_size`. If `state_is_tuple` is True, this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch x output_dim]`, Tensor representing the output of the LSTM after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. - Tensor(s) representing the new state of LSTM after reading `inputs` when the previous state was `state`. Same type and shape(s) as `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. """ sigmoid = math_ops.sigmoid num_proj = self._num_units if self._num_proj is None else self._num_proj if self._state_is_tuple: (c_prev, m_prev) = state else: c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) dtype = inputs.dtype input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") concat_w = _get_concat_variable( "W", [input_size.value + num_proj, 3 * self._num_units], dtype, self._num_unit_shards) b = vs.get_variable( "B", shape=[3 * self._num_units], initializer=init_ops.zeros_initializer(), dtype=dtype) # j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat([inputs, m_prev], 1) lstm_matrix = math_ops.matmul(cell_inputs, concat_w) # If layer nomalization is applied, do not add bias if not self._layer_norm: lstm_matrix = nn_ops.bias_add(lstm_matrix, b) j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) # Apply layer normalization if self._layer_norm: j = _norm(self._norm_gain, self._norm_shift, j, "transform") f = _norm(self._norm_gain, self._norm_shift, f, "forget") o = _norm(self._norm_gain, self._norm_shift, o, "output") # Diagonal connections if self._use_peepholes: w_f_diag = vs.get_variable( "W_F_diag", shape=[self._num_units], dtype=dtype) w_o_diag = vs.get_variable( "W_O_diag", shape=[self._num_units], dtype=dtype) if self._use_peepholes: f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev) else: f_act = sigmoid(f + self._forget_bias) c = (f_act * c_prev + (1 - f_act) * self._activation(j)) # Apply layer normalization if self._layer_norm: c = _norm(self._norm_gain, self._norm_shift, c, "state") if self._use_peepholes: m = sigmoid(o + w_o_diag * c) * self._activation(c) else: m = sigmoid(o) * 
self._activation(c) if self._num_proj is not None: concat_w_proj = _get_concat_variable("W_P", [self._num_units, self._num_proj], dtype, self._num_proj_shards) m = math_ops.matmul(m, concat_w_proj) if self._proj_clip is not None: # pylint: disable=invalid-unary-operand-type m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) # pylint: enable=invalid-unary-operand-type new_state = ( rnn_cell_impl.LSTMStateTuple(c, m) if self._state_is_tuple else array_ops.concat([c, m], 1)) return m, new_state class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): """Time-Frequency Long short-term memory unit (LSTM) recurrent network cell. This implementation is based on: Tara N. Sainath and Bo Li "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures for LVCSR Tasks." submitted to INTERSPEECH, 2016. It uses peep-hole connections and optional cell clipping. """ def __init__(self, num_units, use_peepholes=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=1, reuse=None): """Initialize the parameters for an LSTM cell. Args: num_units: int, The number of units in the LSTM cell use_peepholes: bool, set True to enable diagonal/peephole connections. cell_clip: (optional) A float value, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight and projection matrices. num_unit_shards: int, How to split the weight matrix. If >1, the weight matrix is stored across num_unit_shards. forget_bias: float, Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. feature_size: int, The size of the input feature the LSTM spans over. frequency_skip: int, The amount the LSTM filter is shifted by in frequency. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(TimeFreqLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip self._initializer = initializer self._num_unit_shards = num_unit_shards self._forget_bias = forget_bias self._feature_size = feature_size self._frequency_skip = frequency_skip self._state_size = 2 * num_units self._output_size = num_units self._reuse = reuse @property def output_size(self): return self._output_size @property def state_size(self): return self._state_size def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, batch x num_units. state: state Tensor, 2D, batch x state_size. Returns: A tuple containing: - A 2D, batch x output_dim, Tensor representing the output of the LSTM after reading "inputs" when previous state was "state". Here output_dim is num_units. - A 2D, batch x state_size, Tensor representing the new state of LSTM after reading "inputs" when previous state was "state". Raises: ValueError: if an input_size was specified and the provided inputs have a different dimension. 
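
    As a concrete illustration (the sizes here are arbitrary assumptions): with
    an input of dimension 40, `feature_size=8` and `frequency_skip=1`,
    `_make_tf_features` produces (40 - 8) / 1 + 1 = 33 frequency windows, and
    the returned state concatenates one (c, m) pair per window.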
""" sigmoid = math_ops.sigmoid tanh = math_ops.tanh freq_inputs = self._make_tf_features(inputs) dtype = inputs.dtype actual_input_size = freq_inputs[0].get_shape().as_list()[1] concat_w = _get_concat_variable( "W", [actual_input_size + 2 * self._num_units, 4 * self._num_units], dtype, self._num_unit_shards) b = vs.get_variable( "B", shape=[4 * self._num_units], initializer=init_ops.zeros_initializer(), dtype=dtype) # Diagonal connections if self._use_peepholes: w_f_diag = vs.get_variable( "W_F_diag", shape=[self._num_units], dtype=dtype) w_i_diag = vs.get_variable( "W_I_diag", shape=[self._num_units], dtype=dtype) w_o_diag = vs.get_variable( "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], dtype) for fq in range(len(freq_inputs)): c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) m_prev = array_ops.slice(state, [0, (2 * fq + 1) * self._num_units], [-1, self._num_units]) # i = input_gate, j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq], 1) lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) i, j, f, o = array_ops.split( value=lstm_matrix, num_or_size_splits=4, axis=1) if self._use_peepholes: c = ( sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + sigmoid(i + w_i_diag * c_prev) * tanh(j)) else: c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type if self._use_peepholes: m = sigmoid(o + w_o_diag * c) * tanh(c) else: m = sigmoid(o) * tanh(c) m_prev_freq = m if fq == 0: state_out = array_ops.concat([c, m], 1) m_out = m else: state_out = array_ops.concat([state_out, c, m], 1) m_out = array_ops.concat([m_out, m], 1) return m_out, state_out def _make_tf_features(self, input_feat): """Make the frequency features. Args: input_feat: input Tensor, 2D, batch x num_units. Returns: A list of frequency features, with each element containing: - A 2D, batch x output_dim, Tensor representing the time-frequency feature for that frequency index. Here output_dim is feature_size. Raises: ValueError: if input_size cannot be inferred from static shape inference. """ input_size = input_feat.get_shape().with_rank(2)[-1].value if input_size is None: raise ValueError("Cannot infer input_size from static shape inference.") num_feats = int( (input_size - self._feature_size) / (self._frequency_skip)) + 1 freq_inputs = [] for f in range(num_feats): cur_input = array_ops.slice(input_feat, [0, f * self._frequency_skip], [-1, self._feature_size]) freq_inputs.append(cur_input) return freq_inputs class GridLSTMCell(rnn_cell_impl.RNNCell): """Grid Long short-term memory unit (LSTM) recurrent network cell. The default is based on: Nal Kalchbrenner, Ivo Danihelka and Alex Graves "Grid Long Short-Term Memory," Proc. ICLR 2016. http://arxiv.org/abs/1507.01526 When peephole connections are used, the implementation is based on: Tara N. Sainath and Bo Li "Modeling Time-Frequency Patterns with LSTM vs. Convolutional Architectures for LVCSR Tasks." submitted to INTERSPEECH, 2016. The code uses optional peephole connections, shared_weights and cell clipping. 
""" def __init__(self, num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None, num_frequency_blocks=None, start_freqindex_list=None, end_freqindex_list=None, couple_input_forget_gates=False, state_is_tuple=True, reuse=None): """Initialize the parameters for an LSTM cell. Args: num_units: int, The number of units in the LSTM cell use_peepholes: (optional) bool, default False. Set True to enable diagonal/peephole connections. share_time_frequency_weights: (optional) bool, default False. Set True to enable shared cell weights between time and frequency LSTMs. cell_clip: (optional) A float value, default None, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight and projection matrices, default None. num_unit_shards: (optional) int, default 1, How to split the weight matrix. If > 1,the weight matrix is stored across num_unit_shards. forget_bias: (optional) float, default 1.0, The initial bias of the forget gates, used to reduce the scale of forgetting at the beginning of the training. feature_size: (optional) int, default None, The size of the input feature the LSTM spans over. frequency_skip: (optional) int, default None, The amount the LSTM filter is shifted by in frequency. num_frequency_blocks: [required] A list of frequency blocks needed to cover the whole input feature splitting defined by start_freqindex_list and end_freqindex_list. start_freqindex_list: [optional], list of ints, default None, The starting frequency index for each frequency block. end_freqindex_list: [optional], list of ints, default None. The ending frequency index for each frequency block. couple_input_forget_gates: (optional) bool, default False, Whether to couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce model parameters and computation cost. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. By default (False), they are concatenated along the column axis. This default behavior will soon be deprecated. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. Raises: ValueError: if the num_frequency_blocks list is not specified """ super(GridLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. 
Use state_is_tuple=True.", self) self._num_units = num_units self._use_peepholes = use_peepholes self._share_time_frequency_weights = share_time_frequency_weights self._couple_input_forget_gates = couple_input_forget_gates self._state_is_tuple = state_is_tuple self._cell_clip = cell_clip self._initializer = initializer self._num_unit_shards = num_unit_shards self._forget_bias = forget_bias self._feature_size = feature_size self._frequency_skip = frequency_skip self._start_freqindex_list = start_freqindex_list self._end_freqindex_list = end_freqindex_list self._num_frequency_blocks = num_frequency_blocks self._total_blocks = 0 self._reuse = reuse if self._num_frequency_blocks is None: raise ValueError("Must specify num_frequency_blocks") for block_index in range(len(self._num_frequency_blocks)): self._total_blocks += int(self._num_frequency_blocks[block_index]) if state_is_tuple: state_names = "" for block_index in range(len(self._num_frequency_blocks)): for freq_index in range(self._num_frequency_blocks[block_index]): name_prefix = "state_f%02d_b%02d" % (freq_index, block_index) state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) self._state_tuple_type = collections.namedtuple("GridLSTMStateTuple", state_names.strip(",")) self._state_size = self._state_tuple_type(*( [num_units, num_units] * self._total_blocks)) else: self._state_tuple_type = None self._state_size = num_units * self._total_blocks * 2 self._output_size = num_units * self._total_blocks * 2 @property def output_size(self): return self._output_size @property def state_size(self): return self._state_size @property def state_tuple_type(self): return self._state_tuple_type def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, [batch, feature_size]. state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the flag self._state_is_tuple. Returns: A tuple containing: - A 2D, [batch, output_dim], Tensor representing the output of the LSTM after reading "inputs" when previous state was "state". Here output_dim is num_units. - A 2D, [batch, state_size], Tensor representing the new state of LSTM after reading "inputs" when previous state was "state". Raises: ValueError: if an input_size was specified and the provided inputs have a different dimension. """ batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] freq_inputs = self._make_tf_features(inputs) m_out_lst = [] state_out_lst = [] for block in range(len(freq_inputs)): m_out_lst_current, state_out_lst_current = self._compute( freq_inputs[block], block, state, batch_size, state_is_tuple=self._state_is_tuple) m_out_lst.extend(m_out_lst_current) state_out_lst.extend(state_out_lst_current) if self._state_is_tuple: state_out = self._state_tuple_type(*state_out_lst) else: state_out = array_ops.concat(state_out_lst, 1) m_out = array_ops.concat(m_out_lst, 1) return m_out, state_out def _compute(self, freq_inputs, block, state, batch_size, state_prefix="state", state_is_tuple=True): """Run the actual computation of one step LSTM. Args: freq_inputs: list of Tensors, 2D, [batch, feature_size]. block: int, current frequency block index to process. state: Tensor or tuple of Tensors, 2D, [batch, state_size], it depends on the flag state_is_tuple. batch_size: int32, batch size. state_prefix: (optional) string, name prefix for states, defaults to "state". state_is_tuple: boolean, indicates whether the state is a tuple or Tensor. 
Returns: A tuple, containing: - A list of [batch, output_dim] Tensors, representing the output of the LSTM given the inputs and state. - A list of [batch, state_size] Tensors, representing the LSTM state values given the inputs and previous state. """ sigmoid = math_ops.sigmoid tanh = math_ops.tanh num_gates = 3 if self._couple_input_forget_gates else 4 dtype = freq_inputs[0].dtype actual_input_size = freq_inputs[0].get_shape().as_list()[1] concat_w_f = _get_concat_variable( "W_f_%d" % block, [actual_input_size + 2 * self._num_units, num_gates * self._num_units], dtype, self._num_unit_shards) b_f = vs.get_variable( "B_f_%d" % block, shape=[num_gates * self._num_units], initializer=init_ops.zeros_initializer(), dtype=dtype) if not self._share_time_frequency_weights: concat_w_t = _get_concat_variable("W_t_%d" % block, [ actual_input_size + 2 * self._num_units, num_gates * self._num_units ], dtype, self._num_unit_shards) b_t = vs.get_variable( "B_t_%d" % block, shape=[num_gates * self._num_units], initializer=init_ops.zeros_initializer(), dtype=dtype) if self._use_peepholes: # Diagonal connections if not self._couple_input_forget_gates: w_f_diag_freqf = vs.get_variable( "W_F_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) w_f_diag_freqt = vs.get_variable( "W_F_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_freqf = vs.get_variable( "W_I_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_freqt = vs.get_variable( "W_I_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) w_o_diag_freqf = vs.get_variable( "W_O_diag_freqf_%d" % block, shape=[self._num_units], dtype=dtype) w_o_diag_freqt = vs.get_variable( "W_O_diag_freqt_%d" % block, shape=[self._num_units], dtype=dtype) if not self._share_time_frequency_weights: if not self._couple_input_forget_gates: w_f_diag_timef = vs.get_variable( "W_F_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) w_f_diag_timet = vs.get_variable( "W_F_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_timef = vs.get_variable( "W_I_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) w_i_diag_timet = vs.get_variable( "W_I_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) w_o_diag_timef = vs.get_variable( "W_O_diag_timef_%d" % block, shape=[self._num_units], dtype=dtype) w_o_diag_timet = vs.get_variable( "W_O_diag_timet_%d" % block, shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero m_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) c_prev_freq = array_ops.zeros([batch_size, self._num_units], dtype) for freq_index in range(len(freq_inputs)): if state_is_tuple: name_prefix = "%s_f%02d_b%02d" % (state_prefix, freq_index, block) c_prev_time = getattr(state, name_prefix + "_c") m_prev_time = getattr(state, name_prefix + "_m") else: c_prev_time = array_ops.slice( state, [0, 2 * freq_index * self._num_units], [-1, self._num_units]) m_prev_time = array_ops.slice( state, [0, (2 * freq_index + 1) * self._num_units], [-1, self._num_units]) # i = input_gate, j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat( [freq_inputs[freq_index], m_prev_time, m_prev_freq], 1) # F-LSTM lstm_matrix_freq = nn_ops.bias_add( math_ops.matmul(cell_inputs, concat_w_f), b_f) if self._couple_input_forget_gates: i_freq, j_freq, o_freq = array_ops.split( value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1) f_freq = None else: i_freq, j_freq, f_freq, o_freq = array_ops.split( value=lstm_matrix_freq, 
num_or_size_splits=num_gates, axis=1) # T-LSTM if self._share_time_frequency_weights: i_time = i_freq j_time = j_freq f_time = f_freq o_time = o_freq else: lstm_matrix_time = nn_ops.bias_add( math_ops.matmul(cell_inputs, concat_w_t), b_t) if self._couple_input_forget_gates: i_time, j_time, o_time = array_ops.split( value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) f_time = None else: i_time, j_time, f_time, o_time = array_ops.split( value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1) # F-LSTM c_freq # input gate activations if self._use_peepholes: i_freq_g = sigmoid(i_freq + w_i_diag_freqf * c_prev_freq + w_i_diag_freqt * c_prev_time) else: i_freq_g = sigmoid(i_freq) # forget gate activations if self._couple_input_forget_gates: f_freq_g = 1.0 - i_freq_g else: if self._use_peepholes: f_freq_g = sigmoid(f_freq + self._forget_bias + w_f_diag_freqf * c_prev_freq + w_f_diag_freqt * c_prev_time) else: f_freq_g = sigmoid(f_freq + self._forget_bias) # cell state c_freq = f_freq_g * c_prev_freq + i_freq_g * tanh(j_freq) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type c_freq = clip_ops.clip_by_value(c_freq, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type # T-LSTM c_freq # input gate activations if self._use_peepholes: if self._share_time_frequency_weights: i_time_g = sigmoid(i_time + w_i_diag_freqf * c_prev_freq + w_i_diag_freqt * c_prev_time) else: i_time_g = sigmoid(i_time + w_i_diag_timef * c_prev_freq + w_i_diag_timet * c_prev_time) else: i_time_g = sigmoid(i_time) # forget gate activations if self._couple_input_forget_gates: f_time_g = 1.0 - i_time_g else: if self._use_peepholes: if self._share_time_frequency_weights: f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_freqf * c_prev_freq + w_f_diag_freqt * c_prev_time) else: f_time_g = sigmoid(f_time + self._forget_bias + w_f_diag_timef * c_prev_freq + w_f_diag_timet * c_prev_time) else: f_time_g = sigmoid(f_time + self._forget_bias) # cell state c_time = f_time_g * c_prev_time + i_time_g * tanh(j_time) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type c_time = clip_ops.clip_by_value(c_time, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type # F-LSTM m_freq if self._use_peepholes: m_freq = sigmoid(o_freq + w_o_diag_freqf * c_freq + w_o_diag_freqt * c_time) * tanh(c_freq) else: m_freq = sigmoid(o_freq) * tanh(c_freq) # T-LSTM m_time if self._use_peepholes: if self._share_time_frequency_weights: m_time = sigmoid(o_time + w_o_diag_freqf * c_freq + w_o_diag_freqt * c_time) * tanh(c_time) else: m_time = sigmoid(o_time + w_o_diag_timef * c_freq + w_o_diag_timet * c_time) * tanh(c_time) else: m_time = sigmoid(o_time) * tanh(c_time) m_prev_freq = m_freq c_prev_freq = c_freq # Concatenate the outputs for T-LSTM and F-LSTM for each shift if freq_index == 0: state_out_lst = [c_time, m_time] m_out_lst = [m_time, m_freq] else: state_out_lst.extend([c_time, m_time]) m_out_lst.extend([m_time, m_freq]) return m_out_lst, state_out_lst def _make_tf_features(self, input_feat, slice_offset=0): """Make the frequency features. Args: input_feat: input Tensor, 2D, [batch, num_units]. slice_offset: (optional) Python int, default 0, the slicing offset is only used for the backward processing in the BidirectionalGridLSTMCell. It specifies a different starting point instead of always 0 to enable the forward and backward processing look at different frequency blocks. 
Returns: A list of frequency features, with each element containing: - A 2D, [batch, output_dim], Tensor representing the time-frequency feature for that frequency index. Here output_dim is feature_size. Raises: ValueError: if input_size cannot be inferred from static shape inference. """ input_size = input_feat.get_shape().with_rank(2)[-1].value if input_size is None: raise ValueError("Cannot infer input_size from static shape inference.") if slice_offset > 0: # Padding to the end inputs = array_ops.pad(input_feat, array_ops.constant( [0, 0, 0, slice_offset], shape=[2, 2], dtype=dtypes.int32), "CONSTANT") elif slice_offset < 0: # Padding to the front inputs = array_ops.pad(input_feat, array_ops.constant( [0, 0, -slice_offset, 0], shape=[2, 2], dtype=dtypes.int32), "CONSTANT") slice_offset = 0 else: inputs = input_feat freq_inputs = [] if not self._start_freqindex_list: if len(self._num_frequency_blocks) != 1: raise ValueError("Length of num_frequency_blocks" " is not 1, but instead is %d", len(self._num_frequency_blocks)) num_feats = int( (input_size - self._feature_size) / (self._frequency_skip)) + 1 if num_feats != self._num_frequency_blocks[0]: raise ValueError( "Invalid num_frequency_blocks, requires %d but gets %d, please" " check the input size and filter config are correct." % (self._num_frequency_blocks[0], num_feats)) block_inputs = [] for f in range(num_feats): cur_input = array_ops.slice( inputs, [0, slice_offset + f * self._frequency_skip], [-1, self._feature_size]) block_inputs.append(cur_input) freq_inputs.append(block_inputs) else: if len(self._start_freqindex_list) != len(self._end_freqindex_list): raise ValueError("Length of start and end freqindex_list" " does not match %d %d", len(self._start_freqindex_list), len(self._end_freqindex_list)) if len(self._num_frequency_blocks) != len(self._start_freqindex_list): raise ValueError("Length of num_frequency_blocks" " is not equal to start_freqindex_list %d %d", len(self._num_frequency_blocks), len(self._start_freqindex_list)) for b in range(len(self._start_freqindex_list)): start_index = self._start_freqindex_list[b] end_index = self._end_freqindex_list[b] cur_size = end_index - start_index block_feats = int( (cur_size - self._feature_size) / (self._frequency_skip)) + 1 if block_feats != self._num_frequency_blocks[b]: raise ValueError( "Invalid num_frequency_blocks, requires %d but gets %d, please" " check the input size and filter config are correct." % (self._num_frequency_blocks[b], block_feats)) block_inputs = [] for f in range(block_feats): cur_input = array_ops.slice( inputs, [0, start_index + slice_offset + f * self._frequency_skip], [-1, self._feature_size]) block_inputs.append(cur_input) freq_inputs.append(block_inputs) return freq_inputs class BidirectionalGridLSTMCell(GridLSTMCell): """Bidirectional GridLstm cell. The bidirection connection is only used in the frequency direction, which hence doesn't affect the time direction's real-time processing that is required for online recognition systems. The current implementation uses different weights for the two directions. """ def __init__(self, num_units, use_peepholes=False, share_time_frequency_weights=False, cell_clip=None, initializer=None, num_unit_shards=1, forget_bias=1.0, feature_size=None, frequency_skip=None, num_frequency_blocks=None, start_freqindex_list=None, end_freqindex_list=None, couple_input_forget_gates=False, backward_slice_offset=0, reuse=None): """Initialize the parameters for an LSTM cell. 
Args: num_units: int, The number of units in the LSTM cell use_peepholes: (optional) bool, default False. Set True to enable diagonal/peephole connections. share_time_frequency_weights: (optional) bool, default False. Set True to enable shared cell weights between time and frequency LSTMs. cell_clip: (optional) A float value, default None, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight and projection matrices, default None. num_unit_shards: (optional) int, default 1, How to split the weight matrix. If > 1,the weight matrix is stored across num_unit_shards. forget_bias: (optional) float, default 1.0, The initial bias of the forget gates, used to reduce the scale of forgetting at the beginning of the training. feature_size: (optional) int, default None, The size of the input feature the LSTM spans over. frequency_skip: (optional) int, default None, The amount the LSTM filter is shifted by in frequency. num_frequency_blocks: [required] A list of frequency blocks needed to cover the whole input feature splitting defined by start_freqindex_list and end_freqindex_list. start_freqindex_list: [optional], list of ints, default None, The starting frequency index for each frequency block. end_freqindex_list: [optional], list of ints, default None. The ending frequency index for each frequency block. couple_input_forget_gates: (optional) bool, default False, Whether to couple the input and forget gates, i.e. f_gate = 1.0 - i_gate, to reduce model parameters and computation cost. backward_slice_offset: (optional) int32, default 0, the starting offset to slice the feature for backward processing. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(BidirectionalGridLSTMCell, self).__init__( num_units, use_peepholes, share_time_frequency_weights, cell_clip, initializer, num_unit_shards, forget_bias, feature_size, frequency_skip, num_frequency_blocks, start_freqindex_list, end_freqindex_list, couple_input_forget_gates, True, reuse) self._backward_slice_offset = int(backward_slice_offset) state_names = "" for direction in ["fwd", "bwd"]: for block_index in range(len(self._num_frequency_blocks)): for freq_index in range(self._num_frequency_blocks[block_index]): name_prefix = "%s_state_f%02d_b%02d" % (direction, freq_index, block_index) state_names += ("%s_c, %s_m," % (name_prefix, name_prefix)) self._state_tuple_type = collections.namedtuple( "BidirectionalGridLSTMStateTuple", state_names.strip(",")) self._state_size = self._state_tuple_type(*( [num_units, num_units] * self._total_blocks * 2)) self._output_size = 2 * num_units * self._total_blocks * 2 def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, [batch, num_units]. state: tuple of Tensors, 2D, [batch, state_size]. Returns: A tuple containing: - A 2D, [batch, output_dim], Tensor representing the output of the LSTM after reading "inputs" when previous state was "state". Here output_dim is num_units. - A 2D, [batch, state_size], Tensor representing the new state of LSTM after reading "inputs" when previous state was "state". Raises: ValueError: if an input_size was specified and the provided inputs have a different dimension. 
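
    A construction sketch for reference (illustrative only; the sizes are
    assumptions and presume an input whose last dimension is 16):

      cell = BidirectionalGridLSTMCell(
          num_units=16, feature_size=4, frequency_skip=1,
          num_frequency_blocks=[13])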
""" batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] fwd_inputs = self._make_tf_features(inputs) if self._backward_slice_offset: bwd_inputs = self._make_tf_features(inputs, self._backward_slice_offset) else: bwd_inputs = fwd_inputs # Forward processing with vs.variable_scope("fwd"): fwd_m_out_lst = [] fwd_state_out_lst = [] for block in range(len(fwd_inputs)): fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute( fwd_inputs[block], block, state, batch_size, state_prefix="fwd_state", state_is_tuple=True) fwd_m_out_lst.extend(fwd_m_out_lst_current) fwd_state_out_lst.extend(fwd_state_out_lst_current) # Backward processing bwd_m_out_lst = [] bwd_state_out_lst = [] with vs.variable_scope("bwd"): for block in range(len(bwd_inputs)): # Reverse the blocks bwd_inputs_reverse = bwd_inputs[block][::-1] bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute( bwd_inputs_reverse, block, state, batch_size, state_prefix="bwd_state", state_is_tuple=True) bwd_m_out_lst.extend(bwd_m_out_lst_current) bwd_state_out_lst.extend(bwd_state_out_lst_current) state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst)) # Outputs are always concated as it is never used separately. m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1) return m_out, state_out # pylint: disable=protected-access _Linear = core_rnn_cell._Linear # pylint: disable=invalid-name # pylint: enable=protected-access class AttentionCellWrapper(rnn_cell_impl.RNNCell): """Basic attention cell wrapper. Implementation based on https://arxiv.org/abs/1409.0473. """ def __init__(self, cell, attn_length, attn_size=None, attn_vec_size=None, input_size=None, state_is_tuple=True, reuse=None): """Create a cell with attention. Args: cell: an RNNCell, an attention is added to it. attn_length: integer, the size of an attention window. attn_size: integer, the size of an attention vector. Equal to cell.output_size by default. attn_vec_size: integer, the number of convolutional features calculated on attention state and a size of the hidden layer built from base cell state. Equal attn_size to by default. input_size: integer, the size of a hidden linear layer, built from inputs and attention. Derived from the input tensor by default. state_is_tuple: If True, accepted and returned states are n-tuples, where `n = len(cells)`. By default (False), the states are all concatenated along the column axis. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. Raises: TypeError: if cell is not an RNNCell. ValueError: if cell returns a state tuple but the flag `state_is_tuple` is `False` or if attn_length is zero or less. """ super(AttentionCellWrapper, self).__init__(_reuse=reuse) if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access raise TypeError("The parameter cell is not RNNCell.") if nest.is_sequence(cell.state_size) and not state_is_tuple: raise ValueError( "Cell returns tuple of states, but the flag " "state_is_tuple is not set. State size is: %s" % str(cell.state_size)) if attn_length <= 0: raise ValueError( "attn_length should be greater than zero, got %s" % str(attn_length)) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. 
Use state_is_tuple=True.", self) if attn_size is None: attn_size = cell.output_size if attn_vec_size is None: attn_vec_size = attn_size self._state_is_tuple = state_is_tuple self._cell = cell self._attn_vec_size = attn_vec_size self._input_size = input_size self._attn_size = attn_size self._attn_length = attn_length self._reuse = reuse self._linear1 = None self._linear2 = None self._linear3 = None @property def state_size(self): size = (self._cell.state_size, self._attn_size, self._attn_size * self._attn_length) if self._state_is_tuple: return size else: return sum(list(size)) @property def output_size(self): return self._attn_size def call(self, inputs, state): """Long short-term memory cell with attention (LSTMA).""" if self._state_is_tuple: state, attns, attn_states = state else: states = state state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size]) attns = array_ops.slice(states, [0, self._cell.state_size], [-1, self._attn_size]) attn_states = array_ops.slice( states, [0, self._cell.state_size + self._attn_size], [-1, self._attn_size * self._attn_length]) attn_states = array_ops.reshape(attn_states, [-1, self._attn_length, self._attn_size]) input_size = self._input_size if input_size is None: input_size = inputs.get_shape().as_list()[1] if self._linear1 is None: self._linear1 = _Linear([inputs, attns], input_size, True) inputs = self._linear1([inputs, attns]) cell_output, new_state = self._cell(inputs, state) if self._state_is_tuple: new_state_cat = array_ops.concat(nest.flatten(new_state), 1) else: new_state_cat = new_state new_attns, new_attn_states = self._attention(new_state_cat, attn_states) with vs.variable_scope("attn_output_projection"): if self._linear2 is None: self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True) output = self._linear2([cell_output, new_attns]) new_attn_states = array_ops.concat( [new_attn_states, array_ops.expand_dims(output, 1)], 1) new_attn_states = array_ops.reshape( new_attn_states, [-1, self._attn_length * self._attn_size]) new_state = (new_state, new_attns, new_attn_states) if not self._state_is_tuple: new_state = array_ops.concat(list(new_state), 1) return output, new_state def _attention(self, query, attn_states): conv2d = nn_ops.conv2d reduce_sum = math_ops.reduce_sum softmax = nn_ops.softmax tanh = math_ops.tanh with vs.variable_scope("attention"): k = vs.get_variable("attn_w", [1, 1, self._attn_size, self._attn_vec_size]) v = vs.get_variable("attn_v", [self._attn_vec_size]) hidden = array_ops.reshape(attn_states, [-1, self._attn_length, 1, self._attn_size]) hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") if self._linear3 is None: self._linear3 = _Linear(query, self._attn_vec_size, True) y = self._linear3(query) y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) a = softmax(s) d = reduce_sum( array_ops.reshape(a, [-1, self._attn_length, 1, 1]) * hidden, [1, 2]) new_attns = array_ops.reshape(d, [-1, self._attn_size]) new_attn_states = array_ops.slice(attn_states, [0, 1, 0], [-1, -1, -1]) return new_attns, new_attn_states class HighwayWrapper(rnn_cell_impl.RNNCell): """RNNCell wrapper that adds highway connection on cell input and output. Based on: R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks", arXiv preprint arXiv:1505.00387, 2015. https://arxiv.org/abs/1505.00387 """ def __init__(self, cell, couple_carry_transform_gates=True, carry_bias_init=1.0): """Constructs a `HighwayWrapper` for `cell`. 
    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform
        gate be coupled.
      carry_bias_init: float, carry gates bias initialization.
    """
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    return self._cell.state_size

  @property
  def output_size(self):
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    with ops.name_scope(type(self).__name__ + "ZeroState",
                        values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    input_size = inp.get_shape().with_rank(2)[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      transform = 1 - carry
    else:
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(-self._carry_bias_init))
      transform = math_ops.sigmoid(
          nn_ops.xw_plus_b(inp, transform_weight, transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)

    # Ensure shapes match
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())

    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)


class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
  """LSTM unit with layer normalization and recurrent dropout.

  This class adds layer normalization and recurrent dropout to a
  basic LSTM unit. Layer normalization implementation is based on:

    https://arxiv.org/abs/1607.06450.

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton

  and is applied before the internal nonlinearities.
  Recurrent dropout is based on:

    https://arxiv.org/abs/1603.05118

  "Recurrent Dropout without Memory Loss"
  Stanislau Semeniuta, Aliaksei Severyn, Erhardt Barth.
  """

  def __init__(self,
               num_units,
               forget_bias=1.0,
               input_size=None,
               activation=math_ops.tanh,
               layer_norm=True,
               norm_gain=1.0,
               norm_shift=0.0,
               dropout_keep_prob=1.0,
               dropout_prob_seed=None,
               reuse=None):
    """Initializes the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
      input_size: Deprecated and unused.
      activation: Activation function of the inner states.
      layer_norm: If `True`, layer normalization will be applied.
      norm_gain: float, The layer normalization gain initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      norm_shift: float, The layer normalization shift initial value. If
        `layer_norm` has been set to `False`, this argument will be ignored.
      dropout_keep_prob: unit Tensor or float between 0 and 1 representing the
        recurrent dropout probability value. 
If float and 1.0, no dropout will be applied. dropout_prob_seed: (optional) integer, the randomness seed. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse) if input_size is not None: logging.warn("%s: The input_size parameter is deprecated.", self) self._num_units = num_units self._activation = activation self._forget_bias = forget_bias self._keep_prob = dropout_keep_prob self._seed = dropout_prob_seed self._layer_norm = layer_norm self._norm_gain = norm_gain self._norm_shift = norm_shift self._reuse = reuse @property def state_size(self): return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) @property def output_size(self): return self._num_units def _norm(self, inp, scope, dtype=dtypes.float32): shape = inp.get_shape()[-1:] gamma_init = init_ops.constant_initializer(self._norm_gain) beta_init = init_ops.constant_initializer(self._norm_shift) with vs.variable_scope(scope): # Initialize beta and gamma for use by layer_norm. vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) normalized = layers.layer_norm(inp, reuse=True, scope=scope) return normalized def _linear(self, args): out_size = 4 * self._num_units proj_size = args.get_shape()[-1] dtype = args.dtype weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) out = math_ops.matmul(args, weights) if not self._layer_norm: bias = vs.get_variable("bias", [out_size], dtype=dtype) out = nn_ops.bias_add(out, bias) return out def call(self, inputs, state): """LSTM cell with layer normalization and recurrent dropout.""" c, h = state args = array_ops.concat([inputs, h], 1) concat = self._linear(args) dtype = args.dtype i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: i = self._norm(i, "input", dtype=dtype) j = self._norm(j, "transform", dtype=dtype) f = self._norm(f, "forget", dtype=dtype) o = self._norm(o, "output", dtype=dtype) g = self._activation(j) if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: g = nn_ops.dropout(g, self._keep_prob, seed=self._seed) new_c = ( c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: new_c = self._norm(new_c, "state", dtype=dtype) new_h = self._activation(new_c) * math_ops.sigmoid(o) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) return new_h, new_state class NASCell(rnn_cell_impl.RNNCell): """Neural Architecture Search (NAS) recurrent network cell. This implements the recurrent cell from the paper: https://arxiv.org/abs/1611.01578 Barret Zoph and Quoc V. Le. "Neural Architecture Search with Reinforcement Learning" Proc. ICLR 2017. The class uses an optional projection layer. """ def __init__(self, num_units, num_proj=None, use_biases=False, reuse=None): """Initialize the parameters for a NAS cell. Args: num_units: int, The number of units in the NAS cell num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. use_biases: (optional) bool, If True then use biases within the cell. This is False by default. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. 
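
    Example (an illustrative sketch; the sizes, the `inputs` tensor and the
    TF 1.x `tf.nn.dynamic_rnn` driver are assumptions, not requirements):

      cell = NASCell(num_units=128, num_proj=64)
      outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)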
""" super(NASCell, self).__init__(_reuse=reuse) self._num_units = num_units self._num_proj = num_proj self._use_biases = use_biases self._reuse = reuse if num_proj is not None: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) self._output_size = num_proj else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size def call(self, inputs, state): """Run one step of NAS Cell. Args: inputs: input Tensor, 2D, batch x num_units. state: This must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch x output_dim]`, Tensor representing the output of the NAS Cell after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. - Tensor(s) representing the new state of NAS Cell after reading `inputs` when the previous state was `state`. Same type and shape(s) as `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. """ sigmoid = math_ops.sigmoid tanh = math_ops.tanh relu = nn_ops.relu num_proj = self._num_units if self._num_proj is None else self._num_proj (c_prev, m_prev) = state dtype = inputs.dtype input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") # Variables for the NAS cell. W_m is all matrices multiplying the # hiddenstate and W_inputs is all matrices multiplying the inputs. concat_w_m = vs.get_variable("recurrent_kernel", [num_proj, 8 * self._num_units], dtype) concat_w_inputs = vs.get_variable( "kernel", [input_size.value, 8 * self._num_units], dtype) m_matrix = math_ops.matmul(m_prev, concat_w_m) inputs_matrix = math_ops.matmul(inputs, concat_w_inputs) if self._use_biases: b = vs.get_variable( "bias", shape=[8 * self._num_units], initializer=init_ops.zeros_initializer(), dtype=dtype) m_matrix = nn_ops.bias_add(m_matrix, b) # The NAS cell branches into 8 different splits for both the hiddenstate # and the input m_matrix_splits = array_ops.split( axis=1, num_or_size_splits=8, value=m_matrix) inputs_matrix_splits = array_ops.split( axis=1, num_or_size_splits=8, value=inputs_matrix) # First layer layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0]) layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1]) layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2]) layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3]) layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4]) layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5]) layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6]) layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7]) # Second layer l2_0 = tanh(layer1_0 * layer1_1) l2_1 = tanh(layer1_2 + layer1_3) l2_2 = tanh(layer1_4 * layer1_5) l2_3 = sigmoid(layer1_6 + layer1_7) # Inject the cell l2_0 = tanh(l2_0 + c_prev) # Third layer l3_0_pre = l2_0 * l2_1 new_c = l3_0_pre # create new cell l3_0 = l3_0_pre l3_1 = tanh(l2_2 + l2_3) # Final layer new_m = tanh(l3_0 * l3_1) # Projection layer if specified if self._num_proj is not None: concat_w_proj = vs.get_variable("projection_weights", [self._num_units, self._num_proj], dtype) new_m = math_ops.matmul(new_m, concat_w_proj) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_m) return new_m, new_state class 
UGRNNCell(rnn_cell_impl.RNNCell): """Update Gate Recurrent Neural Network (UGRNN) cell. Compromise between a LSTM/GRU and a vanilla RNN. There is only one gate, and that is to determine whether the unit should be integrating or computing instantaneously. This is the recurrent idea of the feedforward Highway Network. This implements the recurrent cell from the paper: https://arxiv.org/abs/1611.09913 Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo. "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. """ def __init__(self, num_units, initializer=None, forget_bias=1.0, activation=math_ops.tanh, reuse=None): """Initialize the parameters for an UGRNN cell. Args: num_units: int, The number of units in the UGRNN cell initializer: (optional) The initializer to use for the weight matrices. forget_bias: (optional) float, default 1.0, The initial bias of the forget gate, used to reduce the scale of forgetting at the beginning of the training. activation: (optional) Activation function of the inner states. Default is `tf.tanh`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(UGRNNCell, self).__init__(_reuse=reuse) self._num_units = num_units self._initializer = initializer self._forget_bias = forget_bias self._activation = activation self._reuse = reuse self._linear = None @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units def call(self, inputs, state): """Run one step of UGRNN. Args: inputs: input Tensor, 2D, batch x input size. state: state Tensor, 2D, batch x num units. Returns: new_output: batch x num units, Tensor representing the output of the UGRNN after reading `inputs` when previous state was `state`. Identical to `new_state`. new_state: batch x num units, Tensor representing the state of the UGRNN after reading `inputs` when previous state was `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. """ sigmoid = math_ops.sigmoid input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") with vs.variable_scope( vs.get_variable_scope(), initializer=self._initializer): cell_inputs = array_ops.concat([inputs, state], 1) if self._linear is None: self._linear = _Linear(cell_inputs, 2 * self._num_units, True) rnn_matrix = self._linear(cell_inputs) [g_act, c_act] = array_ops.split( axis=1, num_or_size_splits=2, value=rnn_matrix) c = self._activation(c_act) g = sigmoid(g_act + self._forget_bias) new_state = g * state + (1.0 - g) * c new_output = new_state return new_output, new_state class IntersectionRNNCell(rnn_cell_impl.RNNCell): """Intersection Recurrent Neural Network (+RNN) cell. Architecture with coupled recurrent gate as well as coupled depth gate, designed to improve information flow through stacked RNNs. As the architecture uses depth gating, the dimensionality of the depth output (y) also should not change through depth (input size == output size). To achieve this, the first layer of a stacked Intersection RNN projects the inputs to N (num units) dimensions. Therefore when initializing an IntersectionRNNCell, one should set `num_in_proj = N` for the first layer and use default settings for subsequent layers. 
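
  For example (an illustrative sketch; the layer sizes are arbitrary):

    first_layer = IntersectionRNNCell(num_units=64, num_in_proj=64)
    other_layers = [IntersectionRNNCell(num_units=64) for _ in range(2)]
    stacked_cell = tf.contrib.rnn.MultiRNNCell([first_layer] + other_layers)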
This implements the recurrent cell from the paper: https://arxiv.org/abs/1611.09913 Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo. "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017. The Intersection RNN is built for use in deeply stacked RNNs so it may not achieve best performance with depth 1. """ def __init__(self, num_units, num_in_proj=None, initializer=None, forget_bias=1.0, y_activation=nn_ops.relu, reuse=None): """Initialize the parameters for an +RNN cell. Args: num_units: int, The number of units in the +RNN cell num_in_proj: (optional) int, The input dimensionality for the RNN. If creating the first layer of an +RNN, this should be set to `num_units`. Otherwise, this should be set to `None` (default). If `None`, dimensionality of `inputs` should be equal to `num_units`, otherwise ValueError is thrown. initializer: (optional) The initializer to use for the weight matrices. forget_bias: (optional) float, default 1.0, The initial bias of the forget gates, used to reduce the scale of forgetting at the beginning of the training. y_activation: (optional) Activation function of the states passed through depth. Default is 'tf.nn.relu`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(IntersectionRNNCell, self).__init__(_reuse=reuse) self._num_units = num_units self._initializer = initializer self._forget_bias = forget_bias self._num_input_proj = num_in_proj self._y_activation = y_activation self._reuse = reuse self._linear1 = None self._linear2 = None @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units def call(self, inputs, state): """Run one step of the Intersection RNN. Args: inputs: input Tensor, 2D, batch x input size. state: state Tensor, 2D, batch x num units. Returns: new_y: batch x num units, Tensor representing the output of the +RNN after reading `inputs` when previous state was `state`. new_state: batch x num units, Tensor representing the state of the +RNN after reading `inputs` when previous state was `state`. Raises: ValueError: If input size cannot be inferred from `inputs` via static shape inference. ValueError: If input size != output size (these must be equal when using the Intersection RNN). """ sigmoid = math_ops.sigmoid tanh = math_ops.tanh input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") with vs.variable_scope( vs.get_variable_scope(), initializer=self._initializer): # read-in projections (should be used for first layer in deep +RNN # to transform size of inputs from I --> N) if input_size.value != self._num_units: if self._num_input_proj: with vs.variable_scope("in_projection"): if self._linear1 is None: self._linear1 = _Linear(inputs, self._num_units, True) inputs = self._linear1(inputs) else: raise ValueError("Must have input size == output size for " "Intersection RNN. 
To fix, num_in_proj should " "be set to num_units at cell init.") n_dim = i_dim = self._num_units cell_inputs = array_ops.concat([inputs, state], 1) if self._linear2 is None: self._linear2 = _Linear(cell_inputs, 2 * n_dim + 2 * i_dim, True) rnn_matrix = self._linear2(cell_inputs) gh_act = rnn_matrix[:, :n_dim] # b x n h_act = rnn_matrix[:, n_dim:2 * n_dim] # b x n gy_act = rnn_matrix[:, 2 * n_dim:2 * n_dim + i_dim] # b x i y_act = rnn_matrix[:, 2 * n_dim + i_dim:2 * n_dim + 2 * i_dim] # b x i h = tanh(h_act) y = self._y_activation(y_act) gh = sigmoid(gh_act + self._forget_bias) gy = sigmoid(gy_act + self._forget_bias) new_state = gh * state + (1.0 - gh) * h # passed thru time new_y = gy * inputs + (1.0 - gy) * y # passed thru depth return new_y, new_state _REGISTERED_OPS = None class CompiledWrapper(rnn_cell_impl.RNNCell): """Wraps step execution in an XLA JIT scope.""" def __init__(self, cell, compile_stateful=False): """Create CompiledWrapper cell. Args: cell: Instance of `RNNCell`. compile_stateful: Whether to compile stateful ops like initializers and random number generators (default: False). """ self._cell = cell self._compile_stateful = compile_stateful @property def state_size(self): return self._cell.state_size @property def output_size(self): return self._cell.output_size def zero_state(self, batch_size, dtype): with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): return self._cell.zero_state(batch_size, dtype) def __call__(self, inputs, state, scope=None): if self._compile_stateful: compile_ops = True else: def compile_ops(node_def): global _REGISTERED_OPS if _REGISTERED_OPS is None: _REGISTERED_OPS = op_def_registry.get_registered_ops() return not _REGISTERED_OPS[node_def.op].is_stateful with jit.experimental_jit_scope(compile_ops=compile_ops): return self._cell(inputs, state, scope=scope) def _random_exp_initializer(minval, maxval, seed=None, dtype=dtypes.float32): """Returns an exponential distribution initializer. Args: minval: float or a scalar float Tensor. With value > 0. Lower bound of the range of random values to generate. maxval: float or a scalar float Tensor. With value > minval. Upper bound of the range of random values to generate. seed: An integer. Used to create random seeds. dtype: The data type. Returns: An initializer that generates tensors with an exponential distribution. """ def _initializer(shape, dtype=dtype, partition_info=None): del partition_info # Unused. return math_ops.exp( random_ops.random_uniform( shape, math_ops.log(minval), math_ops.log(maxval), dtype, seed=seed)) return _initializer class PhasedLSTMCell(rnn_cell_impl.RNNCell): """Phased LSTM recurrent network cell. https://arxiv.org/pdf/1610.09513v1.pdf """ def __init__(self, num_units, use_peepholes=False, leak=0.001, ratio_on=0.1, trainable_ratio_on=True, period_init_min=1.0, period_init_max=1000.0, reuse=None): """Initialize the Phased LSTM cell. Args: num_units: int, The number of units in the Phased LSTM cell. use_peepholes: bool, set True to enable peephole connections. leak: float or scalar float Tensor with value in [0, 1]. Leak applied during training. ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the period during which the gates are open. trainable_ratio_on: bool, weather ratio_on is trainable. period_init_min: float or scalar float Tensor. With value > 0. Minimum value of the initialized period. The period values are initialized by drawing from the distribution: e^U(log(period_init_min), log(period_init_max)) Where U(.,.) 
is the uniform distribution. period_init_max: float or scalar float Tensor. With value > period_init_min. Maximum value of the initialized period. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(PhasedLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._use_peepholes = use_peepholes self._leak = leak self._ratio_on = ratio_on self._trainable_ratio_on = trainable_ratio_on self._period_init_min = period_init_min self._period_init_max = period_init_max self._reuse = reuse self._linear1 = None self._linear2 = None self._linear3 = None @property def state_size(self): return rnn_cell_impl.LSTMStateTuple(self._num_units, self._num_units) @property def output_size(self): return self._num_units def _mod(self, x, y): """Modulo function that propagates x gradients.""" return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x def _get_cycle_ratio(self, time, phase, period): """Compute the cycle ratio in the dtype of the time.""" phase_casted = math_ops.cast(phase, dtype=time.dtype) period_casted = math_ops.cast(period, dtype=time.dtype) shifted_time = time - phase_casted cycle_ratio = self._mod(shifted_time, period_casted) / period_casted return math_ops.cast(cycle_ratio, dtype=dtypes.float32) def call(self, inputs, state): """Phased LSTM Cell. Args: inputs: A tuple of 2 Tensor. The first Tensor has shape [batch, 1], and type float32 or float64. It stores the time. The second Tensor has shape [batch, features_size], and type float32. It stores the features. state: rnn_cell_impl.LSTMStateTuple, state from previous timestep. Returns: A tuple containing: - A Tensor of float32, and shape [batch_size, num_units], representing the output of the cell. - A rnn_cell_impl.LSTMStateTuple, containing 2 Tensors of float32, shape [batch_size, num_units], representing the new state and the output. 
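    For example, the cell is normally driven through `tf.nn.dynamic_rnn` with
    a `(times, features)` tuple as its input (an illustrative sketch; the
    shapes and the `tf.contrib.rnn.PhasedLSTMCell` alias are assumptions):

        import tensorflow as tf

        times = tf.placeholder(tf.float32, [None, 100, 1])      # batch x time x 1
        features = tf.placeholder(tf.float32, [None, 100, 32])  # batch x time x features
        cell = tf.contrib.rnn.PhasedLSTMCell(num_units=64)
        outputs, state = tf.nn.dynamic_rnn(
            cell, (times, features), dtype=tf.float32)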
""" (c_prev, h_prev) = state (time, x) = inputs in_mask_gates = [x, h_prev] if self._use_peepholes: in_mask_gates.append(c_prev) with vs.variable_scope("mask_gates"): if self._linear1 is None: self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) mask_gates = math_ops.sigmoid(self._linear1(in_mask_gates)) [input_gate, forget_gate] = array_ops.split( axis=1, num_or_size_splits=2, value=mask_gates) with vs.variable_scope("new_input"): if self._linear2 is None: self._linear2 = _Linear([x, h_prev], self._num_units, True) new_input = math_ops.tanh(self._linear2([x, h_prev])) new_c = (c_prev * forget_gate + input_gate * new_input) in_out_gate = [x, h_prev] if self._use_peepholes: in_out_gate.append(new_c) with vs.variable_scope("output_gate"): if self._linear3 is None: self._linear3 = _Linear(in_out_gate, self._num_units, True) output_gate = math_ops.sigmoid(self._linear3(in_out_gate)) new_h = math_ops.tanh(new_c) * output_gate period = vs.get_variable( "period", [self._num_units], initializer=_random_exp_initializer(self._period_init_min, self._period_init_max)) phase = vs.get_variable( "phase", [self._num_units], initializer=init_ops.random_uniform_initializer(0., period.initial_value)) ratio_on = vs.get_variable( "ratio_on", [self._num_units], initializer=init_ops.constant_initializer(self._ratio_on), trainable=self._trainable_ratio_on) cycle_ratio = self._get_cycle_ratio(time, phase, period) k_up = 2 * cycle_ratio / ratio_on k_down = 2 - k_up k_closed = self._leak * cycle_ratio k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) new_c = k * new_c + (1 - k) * c_prev new_h = k * new_h + (1 - k) * h_prev new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) return new_h, new_state class ConvLSTMCell(rnn_cell_impl.RNNCell): """Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ def __init__(self, conv_ndims, input_shape, output_channels, kernel_shape, use_bias=True, skip_connection=False, forget_bias=1.0, initializers=None, name="conv_lstm_cell"): """Construct ConvLSTMCell. Args: conv_ndims: Convolution dimensionality (1, 2 or 3). input_shape: Shape of the input as int tuple, excluding the batch size. output_channels: int, number of output channels of the conv LSTM. kernel_shape: Shape of kernel as in tuple (of size 1,2 or 3). use_bias: Use bias in convolutions. skip_connection: If set to `True`, concatenate the input to the output of the conv LSTM. Default: `False`. forget_bias: Forget bias. name: Name of the module. Raises: ValueError: If `skip_connection` is `True` and stride is different from 1 or if `input_shape` is incompatible with `conv_ndims`. 
""" super(ConvLSTMCell, self).__init__(name=name) if conv_ndims != len(input_shape) - 1: raise ValueError("Invalid input_shape {} for conv_ndims={}.".format( input_shape, conv_ndims)) self._conv_ndims = conv_ndims self._input_shape = input_shape self._output_channels = output_channels self._kernel_shape = kernel_shape self._use_bias = use_bias self._forget_bias = forget_bias self._skip_connection = skip_connection self._total_output_channels = output_channels if self._skip_connection: self._total_output_channels += self._input_shape[-1] state_size = tensor_shape.TensorShape( self._input_shape[:-1] + [self._output_channels]) self._state_size = rnn_cell_impl.LSTMStateTuple(state_size, state_size) self._output_size = tensor_shape.TensorShape( self._input_shape[:-1] + [self._total_output_channels]) @property def output_size(self): return self._output_size @property def state_size(self): return self._state_size def call(self, inputs, state, scope=None): cell, hidden = state new_hidden = _conv([inputs, hidden], self._kernel_shape, 4 * self._output_channels, self._use_bias) gates = array_ops.split( value=new_hidden, num_or_size_splits=4, axis=self._conv_ndims + 1) input_gate, new_input, forget_gate, output_gate = gates new_cell = math_ops.sigmoid(forget_gate + self._forget_bias) * cell new_cell += math_ops.sigmoid(input_gate) * math_ops.tanh(new_input) output = math_ops.tanh(new_cell) * math_ops.sigmoid(output_gate) if self._skip_connection: output = array_ops.concat([output, inputs], axis=-1) new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output) return output, new_state class Conv1DLSTMCell(ConvLSTMCell): """1D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ def __init__(self, name="conv_1d_lstm_cell", **kwargs): """Construct Conv1DLSTM. See `ConvLSTMCell` for more details.""" super(Conv1DLSTMCell, self).__init__(conv_ndims=1, **kwargs) class Conv2DLSTMCell(ConvLSTMCell): """2D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ def __init__(self, name="conv_2d_lstm_cell", **kwargs): """Construct Conv2DLSTM. See `ConvLSTMCell` for more details.""" super(Conv2DLSTMCell, self).__init__(conv_ndims=2, **kwargs) class Conv3DLSTMCell(ConvLSTMCell): """3D Convolutional LSTM recurrent network cell. https://arxiv.org/pdf/1506.04214v1.pdf """ def __init__(self, name="conv_3d_lstm_cell", **kwargs): """Construct Conv3DLSTM. See `ConvLSTMCell` for more details.""" super(Conv3DLSTMCell, self).__init__(conv_ndims=3, **kwargs) def _conv(args, filter_size, num_features, bias, bias_start=0.0): """convolution: Args: args: a Tensor or a list of Tensors of dimension 3D, 4D or 5D, batch x n, Tensors. filter_size: int tuple of filter height and width. num_features: int, number of features. bias_start: starting value to initialize the bias; 0 by default. Returns: A 3D, 4D, or 5D Tensor with shape [batch ... num_features] Raises: ValueError: if some of the arguments has unspecified or wrong shape. """ # Calculate the total size of arguments on dimension 1. 
total_arg_size_depth = 0 shapes = [a.get_shape().as_list() for a in args] shape_length = len(shapes[0]) for shape in shapes: if len(shape) not in [3, 4, 5]: raise ValueError("Conv Linear expects 3D, 4D " "or 5D arguments: %s" % str(shapes)) if len(shape) != len(shapes[0]): raise ValueError("Conv Linear expects all args " "to be of same Dimension: %s" % str(shapes)) else: total_arg_size_depth += shape[-1] dtype = [a.dtype for a in args][0] # determine correct conv operation if shape_length == 3: conv_op = nn_ops.conv1d strides = 1 elif shape_length == 4: conv_op = nn_ops.conv2d strides = shape_length * [1] elif shape_length == 5: conv_op = nn_ops.conv3d strides = shape_length * [1] # Now the computation. kernel = vs.get_variable( "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) if len(args) == 1: res = conv_op(args[0], kernel, strides, padding="SAME") else: res = conv_op( array_ops.concat(axis=shape_length - 1, values=args), kernel, strides, padding="SAME") if not bias: return res bias_term = vs.get_variable( "biases", [num_features], dtype=dtype, initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) return res + bias_term class GLSTMCell(rnn_cell_impl.RNNCell): """Group LSTM cell (G-LSTM). The implementation is based on: https://arxiv.org/abs/1703.10722 O. Kuchaiev and B. Ginsburg "Factorization Tricks for LSTM Networks", ICLR 2017 workshop. """ def __init__(self, num_units, initializer=None, num_proj=None, number_of_groups=1, forget_bias=1.0, activation=math_ops.tanh, reuse=None): """Initialize the parameters of G-LSTM cell. Args: num_units: int, The number of units in the G-LSTM cell initializer: (optional) The initializer to use for the weight and projection matrices. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. number_of_groups: (optional) int, number of groups to use. If `number_of_groups` is 1, then it should be equivalent to LSTM cell forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. activation: Activation function of the inner states. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. Raises: ValueError: If `num_units` or `num_proj` is not divisible by `number_of_groups`. 
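    For example, a G-LSTM with 4 groups and a projection layer, fed with
    inputs whose last dimension equals `num_proj` (an illustrative sketch; the
    shapes are assumptions):

        import tensorflow as tf

        inputs = tf.placeholder(tf.float32, [None, 20, 256])  # batch x time x num_proj
        cell = tf.contrib.rnn.GLSTMCell(
            num_units=512, num_proj=256, number_of_groups=4)
        outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)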
""" super(GLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._initializer = initializer self._num_proj = num_proj self._forget_bias = forget_bias self._activation = activation self._number_of_groups = number_of_groups if self._num_units % self._number_of_groups != 0: raise ValueError("num_units must be divisible by number_of_groups") if self._num_proj: if self._num_proj % self._number_of_groups != 0: raise ValueError("num_proj must be divisible by number_of_groups") self._group_shape = [ int(self._num_proj / self._number_of_groups), int(self._num_units / self._number_of_groups) ] else: self._group_shape = [ int(self._num_units / self._number_of_groups), int(self._num_units / self._number_of_groups) ] if num_proj: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) self._output_size = num_proj else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units self._linear1 = [None] * number_of_groups self._linear2 = None @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size def _get_input_for_group(self, inputs, group_id, group_size): """Slices inputs into groups to prepare for processing by cell's groups Args: inputs: cell input or it's previous state, a Tensor, 2D, [batch x num_units] group_id: group id, a Scalar, for which to prepare input group_size: size of the group Returns: subset of inputs corresponding to group "group_id", a Tensor, 2D, [batch x num_units/number_of_groups] """ return array_ops.slice( input_=inputs, begin=[0, group_id * group_size], size=[self._batch_size, group_size], name=("GLSTM_group%d_input_generation" % group_id)) def call(self, inputs, state): """Run one step of G-LSTM. Args: inputs: input Tensor, 2D, [batch x num_units]. state: this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch x output_dim]`, Tensor representing the output of the G-LSTM after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. - LSTMStateTuple representing the new state of G-LSTM cell after reading `inputs` when the previous state was `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. 
""" (c_prev, m_prev) = state self._batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0] dtype = inputs.dtype scope = vs.get_variable_scope() with vs.variable_scope(scope, initializer=self._initializer): i_parts = [] j_parts = [] f_parts = [] o_parts = [] for group_id in range(self._number_of_groups): with vs.variable_scope("group%d" % group_id): x_g_id = array_ops.concat( [ self._get_input_for_group(inputs, group_id, self._group_shape[0]), self._get_input_for_group(m_prev, group_id, self._group_shape[0]) ], axis=1) linear = self._linear1[group_id] if linear is None: linear = _Linear(x_g_id, 4 * self._group_shape[1], False) self._linear1[group_id] = linear R_k = linear(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) j_parts.append(j_k) f_parts.append(f_k) o_parts.append(o_k) bi = vs.get_variable( name="bias_i", shape=[self._num_units], dtype=dtype, initializer=init_ops.constant_initializer(0.0, dtype=dtype)) bj = vs.get_variable( name="bias_j", shape=[self._num_units], dtype=dtype, initializer=init_ops.constant_initializer(0.0, dtype=dtype)) bf = vs.get_variable( name="bias_f", shape=[self._num_units], dtype=dtype, initializer=init_ops.constant_initializer(0.0, dtype=dtype)) bo = vs.get_variable( name="bias_o", shape=[self._num_units], dtype=dtype, initializer=init_ops.constant_initializer(0.0, dtype=dtype)) i = nn_ops.bias_add(array_ops.concat(i_parts, axis=1), bi) j = nn_ops.bias_add(array_ops.concat(j_parts, axis=1), bj) f = nn_ops.bias_add(array_ops.concat(f_parts, axis=1), bf) o = nn_ops.bias_add(array_ops.concat(o_parts, axis=1), bo) c = ( math_ops.sigmoid(f + self._forget_bias) * c_prev + math_ops.sigmoid(i) * math_ops.tanh(j)) m = math_ops.sigmoid(o) * self._activation(c) if self._num_proj is not None: with vs.variable_scope("projection"): if self._linear2 is None: self._linear2 = _Linear(m, self._num_proj, False) m = self._linear2(m) new_state = rnn_cell_impl.LSTMStateTuple(c, m) return m, new_state class LayerNormLSTMCell(rnn_cell_impl.RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. The default non-peephole implementation is based on: http://www.bioinf.jku.at/publications/older/2604.pdf S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. The peephole implementation is based on: https://research.google.com/pubs/archive/43905.pdf Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. The class uses optional peep-hole connections, optional cell clipping, and an optional projection layer. Layer normalization implementation is based on: https://arxiv.org/abs/1607.06450. "Layer Normalization" Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton and is applied before the internal nonlinearities. """ def __init__(self, num_units, use_peepholes=False, cell_clip=None, initializer=None, num_proj=None, proj_clip=None, forget_bias=1.0, activation=None, layer_norm=False, norm_gain=1.0, norm_shift=0.0, reuse=None): """Initialize the parameters for an LSTM cell. Args: num_units: int, The number of units in the LSTM cell use_peepholes: bool, set True to enable diagonal/peephole connections. cell_clip: (optional) A float value, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight and projection matrices. 
num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is provided, then the projected values are clipped elementwise to within `[-proj_clip, proj_clip]`. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. Must set it manually to `0.0` when restoring from CudnnLSTM trained checkpoints. activation: Activation function of the inner states. Default: `tanh`. layer_norm: If `True`, layer normalization will be applied. norm_gain: float, The layer normalization gain initial value. If `layer_norm` has been set to `False`, this argument will be ignored. norm_shift: float, The layer normalization shift initial value. If `layer_norm` has been set to `False`, this argument will be ignored. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. When restoring from CudnnLSTM-trained checkpoints, must use CudnnCompatibleLSTMCell instead. """ super(LayerNormLSTMCell, self).__init__(_reuse=reuse) self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip self._initializer = initializer self._num_proj = num_proj self._proj_clip = proj_clip self._forget_bias = forget_bias self._activation = activation or math_ops.tanh self._layer_norm = layer_norm self._norm_gain = norm_gain self._norm_shift = norm_shift if num_proj: self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) self._output_size = num_proj else: self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) self._output_size = num_units @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size def _linear(self, args, output_size, bias, bias_initializer=None, kernel_initializer=None, layer_norm=False): """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. bias_initializer: starting value to initialize the bias (default is all zeros). kernel_initializer: starting value to initialize the weight. layer_norm: boolean, whether to apply layer normalization. Returns: A 2D Tensor with shape [batch x output_size] taking value sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. Raises: ValueError: if some of the arguments has unspecified or wrong shape. """ if args is None or (nest.is_sequence(args) and not args): raise ValueError("`args` must be specified") if not nest.is_sequence(args): args = [args] # Calculate the total size of arguments on dimension 1. total_arg_size = 0 shapes = [a.get_shape() for a in args] for shape in shapes: if shape.ndims != 2: raise ValueError("linear is expecting 2D arguments: %s" % shapes) if shape[1].value is None: raise ValueError("linear expects shape[1] to be provided for shape %s, " "but saw %s" % (shape, shape[1])) else: total_arg_size += shape[1].value dtype = [a.dtype for a in args][0] # Now the computation. 
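    # Clarifying note: when `layer_norm` is True the bias variable below is
    # still created, but it is not added to the result here; the layer
    # normalization applied in `call` supplies the shift (beta) and gain
    # (gamma) instead.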
scope = vs.get_variable_scope() with vs.variable_scope(scope) as outer_scope: weights = vs.get_variable( "kernel", [total_arg_size, output_size], dtype=dtype, initializer=kernel_initializer) if len(args) == 1: res = math_ops.matmul(args[0], weights) else: res = math_ops.matmul(array_ops.concat(args, 1), weights) if not bias: return res with vs.variable_scope(outer_scope) as inner_scope: inner_scope.set_partitioner(None) if bias_initializer is None: bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) biases = vs.get_variable( "bias", [output_size], dtype=dtype, initializer=bias_initializer) if not layer_norm: res = nn_ops.bias_add(res, biases) return res def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, batch x num_units. state: this must be a tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch x output_dim]`, Tensor representing the output of the LSTM after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. - Tensor(s) representing the new state of LSTM after reading `inputs` when the previous state was `state`. Same type and shape(s) as `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. """ sigmoid = math_ops.sigmoid (c_prev, m_prev) = state dtype = inputs.dtype input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") scope = vs.get_variable_scope() with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: # i = input_gate, j = new_input, f = forget_gate, o = output_gate lstm_matrix = self._linear( [inputs, m_prev], 4 * self._num_units, bias=True, bias_initializer=None, layer_norm=self._layer_norm) i, j, f, o = array_ops.split( value=lstm_matrix, num_or_size_splits=4, axis=1) if self._layer_norm: i = _norm(self._norm_gain, self._norm_shift, i, "input") j = _norm(self._norm_gain, self._norm_shift, j, "transform") f = _norm(self._norm_gain, self._norm_shift, f, "forget") o = _norm(self._norm_gain, self._norm_shift, o, "output") # Diagonal connections if self._use_peepholes: with vs.variable_scope(unit_scope): w_f_diag = vs.get_variable( "w_f_diag", shape=[self._num_units], dtype=dtype) w_i_diag = vs.get_variable( "w_i_diag", shape=[self._num_units], dtype=dtype) w_o_diag = vs.get_variable( "w_o_diag", shape=[self._num_units], dtype=dtype) if self._use_peepholes: c = ( sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) else: c = ( sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * self._activation(j)) if self._layer_norm: c = _norm(self._norm_gain, self._norm_shift, c, "state") if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type if self._use_peepholes: m = sigmoid(o + w_o_diag * c) * self._activation(c) else: m = sigmoid(o) * self._activation(c) if self._num_proj is not None: with vs.variable_scope("projection"): m = self._linear(m, self._num_proj, bias=False) if self._proj_clip is not None: # pylint: disable=invalid-unary-operand-type m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) # pylint: enable=invalid-unary-operand-type new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) return m, new_state class 
SRUCell(rnn_cell_impl._LayerRNNCell): """SRU, Simple Recurrent Unit Implementation based on Training RNNs as Fast as CNNs (cf. https://arxiv.org/abs/1709.02755). This variation of RNN cell is characterized by the simplified data dependence between hidden states of two consecutive time steps. Traditionally, hidden states from a cell at time step t-1 needs to be multiplied with a matrix W_hh before being fed into the ensuing cell at time step t. This flavor of RNN replaces the matrix multiplication between h_{t-1} and W_hh with a pointwise multiplication, resulting in performance gain. Args: num_units: int, The number of units in the SRU cell. activation: Nonlinearity to use. Default: `tanh`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. name: (optional) String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. """ def __init__(self, num_units, activation=None, reuse=None, name=None): super(SRUCell, self).__init__(_reuse=reuse, name=name) self._num_units = num_units self._activation = activation or math_ops.tanh # Restrict inputs to be 2-dimensional matrices self.input_spec = base_layer.InputSpec(ndim=2) @property def state_size(self): return self._num_units @property def output_size(self): return self._num_units def build(self, inputs_shape): if inputs_shape[1].value is None: raise ValueError( "Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) input_depth = inputs_shape[1].value # Here the contributor believes that the following constraints # are implied. The reasoning is explained here with reference to # the paper https://arxiv.org/pdf/1709.02755.pdf upon which this # implementation is based. # In section 2.1 Equation 5, specifically: # h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t # the pointwise operation between r_t and x_t means they have # the same shape (since we are implementing an RNN cell, braodcasting # does not happen to input of a single timestep); by the same # reasons, x_t has the same shape as h_t, essentially mandating that # input_depth = unit_num. if input_depth != self._num_units: raise ValueError("SRU requires input_depth == num_units, got " "input_depth = %s, num_units = %s" % (input_depth, self._num_units)) self._kernel = self.add_variable( rnn_cell_impl._WEIGHTS_VARIABLE_NAME, shape=[input_depth, 3 * self._num_units]) self._bias = self.add_variable( rnn_cell_impl._BIAS_VARIABLE_NAME, shape=[2 * self._num_units], initializer=init_ops.constant_initializer(0.0, dtype=self.dtype)) self._built = True def call(self, inputs, state): """Simple recurrent unit (SRU) with num_units cells.""" U = math_ops.matmul(inputs, self._kernel) x_bar, f_intermediate, r_intermediate = array_ops.split( value=U, num_or_size_splits=3, axis=1) f_r = math_ops.sigmoid( nn_ops.bias_add( array_ops.concat([f_intermediate, r_intermediate], 1), self._bias)) f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1) c = f * state + (1.0 - f) * x_bar h = r * self._activation(c) + (1.0 - r) * inputs return h, c class WeightNormLSTMCell(rnn_cell_impl.RNNCell): """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`. The weight-norm implementation is based on: https://arxiv.org/abs/1602.07868 Tim Salimans, Diederik P. Kingma. 
Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks The default LSTM implementation based on: http://www.bioinf.jku.at/publications/older/2604.pdf S. Hochreiter and J. Schmidhuber. "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. The class uses optional peephole connections, optional cell clipping and an optional projection layer. The optional peephole implementation is based on: https://research.google.com/pubs/archive/43905.pdf Hasim Sak, Andrew Senior, and Francoise Beaufays. "Long short-term memory recurrent neural network architectures for large scale acoustic modeling." INTERSPEECH, 2014. """ def __init__(self, num_units, norm=True, use_peepholes=False, cell_clip=None, initializer=None, num_proj=None, proj_clip=None, forget_bias=1, activation=None, reuse=None): """Initialize the parameters of a weight-normalized LSTM cell. Args: num_units: int, The number of units in the LSTM cell norm: If `True`, apply normalization to the weight matrices. If False, the result is identical to that obtained from `rnn_cell_impl.LSTMCell` use_peepholes: bool, set `True` to enable diagonal/peephole connections. cell_clip: (optional) A float value, if provided the cell state is clipped by this value prior to the cell output activation. initializer: (optional) The initializer to use for the weight matrices. num_proj: (optional) int, The output dimensionality for the projection matrices. If None, no projection is performed. proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is provided, then the projected values are clipped elementwise to within `[-proj_clip, proj_clip]`. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of the training. activation: Activation function of the inner states. Default: `tanh`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. """ super(WeightNormLSTMCell, self).__init__(_reuse=reuse) self._scope = "wn_lstm_cell" self._num_units = num_units self._norm = norm self._initializer = initializer self._use_peepholes = use_peepholes self._cell_clip = cell_clip self._num_proj = num_proj self._proj_clip = proj_clip self._activation = activation or math_ops.tanh self._forget_bias = forget_bias self._weights_variable_name = "kernel" self._bias_variable_name = "bias" if num_proj: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) self._output_size = num_proj else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units @property def state_size(self): return self._state_size @property def output_size(self): return self._output_size def _normalize(self, weight, name): """Apply weight normalization. Args: weight: a 2D tensor with known number of columns. name: string, variable name for the normalizer. Returns: A tensor with the same shape as `weight`. """ output_size = weight.get_shape().as_list()[1] g = vs.get_variable(name, [output_size], dtype=weight.dtype) return nn_impl.l2_normalize(weight, dim=0) * g def _linear(self, args, output_size, norm, bias, bias_initializer=None, kernel_initializer=None): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. Args: args: a 2D Tensor or a list of 2D, batch x n, Tensors. output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. 
bias_initializer: starting value to initialize the bias (default is all zeros). kernel_initializer: starting value to initialize the weight. Returns: A 2D Tensor with shape [batch x output_size] equal to sum_i(args[i] * W[i]), where W[i]s are newly created matrices. Raises: ValueError: if some of the arguments has unspecified or wrong shape. """ if args is None or (nest.is_sequence(args) and not args): raise ValueError("`args` must be specified") if not nest.is_sequence(args): args = [args] # Calculate the total size of arguments on dimension 1. total_arg_size = 0 shapes = [a.get_shape() for a in args] for shape in shapes: if shape.ndims != 2: raise ValueError("linear is expecting 2D arguments: %s" % shapes) if shape[1].value is None: raise ValueError("linear expects shape[1] to be provided for shape %s, " "but saw %s" % (shape, shape[1])) else: total_arg_size += shape[1].value dtype = [a.dtype for a in args][0] # Now the computation. scope = vs.get_variable_scope() with vs.variable_scope(scope) as outer_scope: weights = vs.get_variable( self._weights_variable_name, [total_arg_size, output_size], dtype=dtype, initializer=kernel_initializer) if norm: wn = [] st = 0 with ops.control_dependencies(None): for i in range(len(args)): en = st + shapes[i][1].value wn.append( self._normalize(weights[st:en, :], name="norm_{}".format(i))) st = en weights = array_ops.concat(wn, axis=0) if len(args) == 1: res = math_ops.matmul(args[0], weights) else: res = math_ops.matmul(array_ops.concat(args, 1), weights) if not bias: return res with vs.variable_scope(outer_scope) as inner_scope: inner_scope.set_partitioner(None) if bias_initializer is None: bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) biases = vs.get_variable( self._bias_variable_name, [output_size], dtype=dtype, initializer=bias_initializer) return nn_ops.bias_add(res, biases) def call(self, inputs, state): """Run one step of LSTM. Args: inputs: input Tensor, 2D, batch x num_units. state: A tuple of state Tensors, both `2-D`, with column sizes `c_state` and `m_state`. Returns: A tuple containing: - A `2-D, [batch x output_dim]`, Tensor representing the output of the LSTM after reading `inputs` when previous state was `state`. Here output_dim is: num_proj if num_proj was set, num_units otherwise. - Tensor(s) representing the new state of LSTM after reading `inputs` when the previous state was `state`. Same type and shape(s) as `state`. Raises: ValueError: If input size cannot be inferred from inputs via static shape inference. 
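    For example (an illustrative sketch; the shapes, and the assumption that
    this class is exported as `tf.contrib.rnn.WeightNormLSTMCell`, are not
    guaranteed by this docstring):

        import tensorflow as tf

        inputs = tf.placeholder(tf.float32, [None, 20, 32])  # batch x time x features
        cell = tf.contrib.rnn.WeightNormLSTMCell(num_units=64, norm=True)
        outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)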
""" dtype = inputs.dtype num_units = self._num_units sigmoid = math_ops.sigmoid c, h = state input_size = inputs.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from inputs.get_shape()[-1]") with vs.variable_scope(self._scope, initializer=self._initializer): concat = self._linear( [inputs, h], 4 * num_units, norm=self._norm, bias=True) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._use_peepholes: w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype) w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype) w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype) new_c = ( c * sigmoid(f + self._forget_bias + w_f_diag * c) + sigmoid(i + w_i_diag * c) * self._activation(j)) else: new_c = ( c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j)) if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type if self._use_peepholes: new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c) else: new_h = sigmoid(o) * self._activation(new_c) if self._num_proj is not None: with vs.variable_scope("projection"): new_h = self._linear( new_h, self._num_proj, norm=self._norm, bias=False) if self._proj_clip is not None: # pylint: disable=invalid-unary-operand-type new_h = clip_ops.clip_by_value(new_h, -self._proj_clip, self._proj_clip) # pylint: enable=invalid-unary-operand-type new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) return new_h, new_state