Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py | 201
1 file changed, 71 insertions, 130 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index c2843edaf2..2d65d956a8 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -22,7 +22,6 @@ from __future__ import print_function
 import collections
 import math
 
-from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -62,7 +61,7 @@ class BasicRNNCell(RNNCell):
     """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
     with vs.variable_scope(scope or "basic_rnn_cell"):
       output = self._activation(
-          _linear([inputs, state], self._num_units, True))
+          _linear([inputs, state], self._num_units, True, scope=scope))
     return output, output
@@ -90,13 +89,14 @@ class GRUCell(RNNCell):
         # We start with bias of 1.0 to not reset and not update.
         r, u = array_ops.split(
             value=_linear(
-                [inputs, state], 2 * self._num_units, True, 1.0),
+                [inputs, state], 2 * self._num_units, True, 1.0, scope=scope),
             num_or_size_splits=2,
             axis=1)
         r, u = sigmoid(r), sigmoid(u)
       with vs.variable_scope("candidate"):
         c = self._activation(_linear([inputs, r * state],
-                                     self._num_units, True))
+                                     self._num_units, True,
+                                     scope=scope))
       new_h = u * state + (1 - u) * c
     return new_h, new_h
@@ -176,7 +176,7 @@ class BasicLSTMCell(RNNCell):
         c, h = state
       else:
         c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
-      concat = _linear([inputs, h], 4 * self._num_units, True)
+      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
 
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
@@ -192,13 +192,6 @@ class BasicLSTMCell(RNNCell):
     return new_h, new_state
 
 
-def _maybe_compile(fun, compiled):
-  if not compiled:
-    return fun
-  else:
-    return function.Defun(noinline=True, compiled=True)(fun)
-
-
 class LSTMCell(RNNCell):
   """Long short-term memory unit (LSTM) recurrent network cell.
@@ -226,7 +219,7 @@ class LSTMCell(RNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=None, num_proj_shards=None,
                forget_bias=1.0, state_is_tuple=True,
-               activation=tanh, compiled=False):
+               activation=tanh):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -253,12 +246,6 @@ class LSTMCell(RNNCell):
        the `c_state` and `m_state`.  If False, they are concatenated
        along the column axis.  This latter behavior will soon be deprecated.
       activation: Activation function of the inner states.
-      compiled: Python boolean.  If `True`, the core computation of the LSTM
-        cell is compiled via XLA.  As of now, this provides speedups for
-        most GPU calculations, and on small batch CPU and embedded calculations.
-
-    Raises:
-      ValueError: if compiled=True and state_is_tuple=False (not supported).
     """
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
@@ -270,9 +257,6 @@ class LSTMCell(RNNCell):
          "%s: The num_unit_shards and proj_unit_shards parameters are "
          "deprecated and will be removed in Jan 2017.  "
          "Use a variable scope with a partitioner instead.", self)
-    if not state_is_tuple and compiled:
-      raise ValueError(
-          "Combining state_is_tuple=False and compiled=True is not supported.")
 
     self._num_units = num_units
     self._use_peepholes = use_peepholes
@@ -285,7 +269,6 @@ class LSTMCell(RNNCell):
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
     self._activation = activation
-    self._compiled = compiled
 
     if num_proj:
       self._state_size = (
@@ -334,111 +317,73 @@ class LSTMCell(RNNCell):
     """
     num_proj = self._num_units if self._num_proj is None else self._num_proj
 
-    def _kernel(k_inputs, state_p0, state_p1):
-      """Internal kernel for the single step of LSTM.
-
-      Args:
-        k_inputs: Input Tensor.
-        state_p0: Either the state or the c component of the state.
-        state_p1: Either the state or the m component of the state.
-
-      Returns:
-        (m, c) or (m, concat([c, m])) depending on state_is_tuple.
-
-      Raises:
-        ValueError: see above docstring.
-      """
-      k_inputs.set_shape(inputs.get_shape())
-      if self._state_is_tuple:
-        (c_prev, m_prev) = state_p0, state_p1
-        c_prev.set_shape(state[0].get_shape())
-        m_prev.set_shape(state[1].get_shape())
+    if self._state_is_tuple:
+      (c_prev, m_prev) = state
+    else:
+      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+    dtype = inputs.dtype
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+    with vs.variable_scope(scope or "lstm_cell",
+                           initializer=self._initializer) as unit_scope:
+      if self._num_unit_shards is not None:
+        unit_scope.set_partitioner(
+            partitioned_variables.fixed_size_partitioner(
+                self._num_unit_shards))
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
+                            scope=scope)
+      i, j, f, o = array_ops.split(
+          value=lstm_matrix, num_or_size_splits=4, axis=1)
+
+      # Diagonal connections
+      if self._use_peepholes:
+        with vs.variable_scope(unit_scope) as projection_scope:
+          if self._num_unit_shards is not None:
+            projection_scope.set_partitioner(None)
+          w_f_diag = vs.get_variable(
+              "w_f_diag", shape=[self._num_units], dtype=dtype)
+          w_i_diag = vs.get_variable(
+              "w_i_diag", shape=[self._num_units], dtype=dtype)
+          w_o_diag = vs.get_variable(
+              "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+      if self._use_peepholes:
+        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
       else:
-        k_state = state_p0
-        c_prev = array_ops.slice(k_state, [0, 0], [-1, self._num_units])
-        m_prev = array_ops.slice(k_state, [0, self._num_units], [-1, num_proj])
-
-      dtype = k_inputs.dtype
-      input_size = k_inputs.get_shape().with_rank(2)[1]
-      if input_size.value is None:
-        raise ValueError(
-            "Could not infer input size from inputs.get_shape()[-1]")
-      with vs.variable_scope(scope or "lstm_cell",
-                             initializer=self._initializer) as unit_scope:
-        if self._num_unit_shards is not None:
-          unit_scope.set_partitioner(
-              partitioned_variables.fixed_size_partitioner(
-                  self._num_unit_shards))
-        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-        lstm_matrix = _linear(
-            [k_inputs, m_prev], 4 * self._num_units, bias=True,
-            compiled=self._compiled)
-        i, j, f, o = array_ops.split(
-            value=lstm_matrix, num_or_size_splits=4, axis=1)
-
-        # Diagonal connections
-        if self._use_peepholes:
-          with vs.variable_scope(unit_scope) as projection_scope:
-            if self._num_unit_shards is not None:
-              projection_scope.set_partitioner(None)
-            w_f_diag = vs.get_variable(
-                "w_f_diag", shape=[self._num_units], dtype=dtype)
-            w_i_diag = vs.get_variable(
-                "w_i_diag", shape=[self._num_units], dtype=dtype)
-            w_o_diag = vs.get_variable(
-                "w_o_diag", shape=[self._num_units], dtype=dtype)
-          c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
-               sigmoid(i + w_i_diag * c_prev) * self._activation(j))
-        else:
-          c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
-               self._activation(j))
-
-        if self._cell_clip is not None:
-          # pylint: disable=invalid-unary-operand-type
-          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
-          # pylint: enable=invalid-unary-operand-type
-
-        if self._use_peepholes:
-          m = sigmoid(o + w_o_diag * c) * self._activation(c)
-        else:
-          m = sigmoid(o) * self._activation(c)
+        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+             self._activation(j))
 
-        if self._num_proj is not None:
-          with vs.variable_scope("projection") as proj_scope:
-            if self._num_proj_shards is not None:
-              proj_scope.set_partitioner(
-                  partitioned_variables.fixed_size_partitioner(
-                      self._num_proj_shards))
-            m = _linear(m, self._num_proj, bias=False, compiled=self._compiled)
+      if self._cell_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+        # pylint: enable=invalid-unary-operand-type
 
-          if self._proj_clip is not None:
-            # pylint: disable=invalid-unary-operand-type
-            m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
-            # pylint: enable=invalid-unary-operand-type
-
-      if self._state_is_tuple:
-        return m, c
+      if self._use_peepholes:
+        m = sigmoid(o + w_o_diag * c) * self._activation(c)
       else:
-        return m, array_ops.concat([c, m], 1)
+        m = sigmoid(o) * self._activation(c)
 
-    compiled_kernel = _maybe_compile(_kernel, self._compiled)
+      if self._num_proj is not None:
+        with vs.variable_scope("projection") as proj_scope:
+          if self._num_proj_shards is not None:
+            proj_scope.set_partitioner(
+                partitioned_variables.fixed_size_partitioner(
+                    self._num_proj_shards))
+          m = _linear(m, self._num_proj, bias=False, scope=scope)
 
-    if self._state_is_tuple:
-      batch_shape = (
-          inputs.get_shape()[:1].merge_with(
-              state[0].get_shape()[:1]).merge_with(
-                  state[1].get_shape()[:1]))
-      emit_m, emit_c = compiled_kernel(inputs, state[0], state[1])
-      emit_c.set_shape(batch_shape.concatenate([state[0].get_shape()[1]]))
-      emit_m.set_shape(batch_shape.concatenate([state[1].get_shape()[1]]))
-      emit_state = LSTMStateTuple(emit_c, emit_m)
-    else:
-      batch_shape = inputs.get_shape()[:1].merge_with(state.get_shape()[:1])
-      emit_m, emit_state = compiled_kernel(inputs, state, state)
-      emit_m.set_shape(batch_shape.concatenate([num_proj]))
-      emit_state.set_shape(batch_shape.concatenate([state.get_shape()[1]]))
+          if self._proj_clip is not None:
+            # pylint: disable=invalid-unary-operand-type
+            m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+            # pylint: enable=invalid-unary-operand-type
 
-    return emit_m, emit_state
+    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
+                 array_ops.concat([c, m], 1))
+    return m, new_state
 
 
 class OutputProjectionWrapper(RNNCell):
@@ -481,7 +426,7 @@ class OutputProjectionWrapper(RNNCell):
     output, res_state = self._cell(inputs, state)
     # Default scope: "OutputProjectionWrapper"
     with vs.variable_scope(scope or "output_projection_wrapper"):
-      projected = _linear(output, self._output_size, True)
+      projected = _linear(output, self._output_size, True, scope=scope)
     return projected, res_state
@@ -523,7 +468,7 @@ class InputProjectionWrapper(RNNCell):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
     with vs.variable_scope(scope or "input_projection_wrapper"):
-      projected = _linear(inputs, self._num_proj, True)
+      projected = _linear(inputs, self._num_proj, True, scope=scope)
     return self._cell(projected, state)
@@ -817,7 +762,7 @@ class _SlimRNNCell(RNNCell):
     return output, state
 
 
-def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
+def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
 
   Args:
@@ -825,7 +770,7 @@ def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
     output_size: int, second dimension of W[i].
     bias: boolean, whether to add a bias term or not.
    bias_start: starting value to initialize the bias; 0 by default.
-    compiled: boolean, _linear plays nicely with XLA if it is enabled.
+    scope: (optional) Variable scope to create parameters in.
 
   Returns:
     A 2D Tensor with shape [batch x output_size] equal to
@@ -870,8 +815,4 @@ def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
        "biases", [output_size],
        dtype=dtype,
        initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
-  if compiled:
-    # TODO(b/34505635): Defuns don't play well with bias_add
-    return res + biases
-  else:
-    return nn_ops.bias_add(res, biases)
+  return nn_ops.bias_add(res, biases)
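The user-visible effect of this change is that `_linear` now threads a `scope` argument instead of a `compiled` flag, and `LSTMCell` no longer accepts `compiled=` at construction, so XLA compilation of the cell body is no longer requested through this class. The following is a hedged usage sketch only, assuming a TensorFlow 1.x graph-mode environment; the `dynamic_rnn` call, placeholder shape, and hyperparameter values are illustrative and not part of this diff.

import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl

# Construct the cell without the removed `compiled` argument; only parameters
# still present in the signature above are used here.
cell = core_rnn_cell_impl.LSTMCell(
    num_units=128, use_peepholes=True, cell_clip=3.0, state_is_tuple=True)

# Illustrative input: [batch, time, features]; unrolled with dynamic_rnn.
inputs = tf.placeholder(tf.float32, [None, 30, 64])
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)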
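For readers tracing the restored inline LSTM step in the large hunk above, here is a minimal NumPy sketch (not TensorFlow code, and not part of this change) of the arithmetic `LSTMCell.__call__` now performs directly: one fused linear map split into the i, j, f, o gates, optional peephole ("diagonal") terms, cell clipping, and an optional output projection. All names and shapes are illustrative.

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, c_prev, m_prev, w, b, forget_bias=1.0,
              w_f_diag=None, w_i_diag=None, w_o_diag=None,
              cell_clip=None, w_proj=None):
  """One LSTM step. x: [batch, input_size]; c_prev, m_prev: [batch, num_units].
  The three peephole vectors w_*_diag are either all given or all None."""
  # Fused linear map over [inputs, m_prev], split into the four gates
  # (i = input, j = new input, f = forget, o = output), as in the diff above.
  lstm_matrix = np.concatenate([x, m_prev], axis=1).dot(w) + b
  i, j, f, o = np.split(lstm_matrix, 4, axis=1)
  if w_f_diag is not None:
    # Peephole connections let the input and forget gates see the previous cell state.
    c = (sigmoid(f + forget_bias + w_f_diag * c_prev) * c_prev +
         sigmoid(i + w_i_diag * c_prev) * np.tanh(j))
  else:
    c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * np.tanh(j)
  if cell_clip is not None:
    c = np.clip(c, -cell_clip, cell_clip)
  if w_o_diag is not None:
    m = sigmoid(o + w_o_diag * c) * np.tanh(c)
  else:
    m = sigmoid(o) * np.tanh(c)
  if w_proj is not None:
    # Optional projection of the output/state m down to num_proj dimensions.
    m = m.dot(w_proj)
  return m, c

# Tiny smoke test with illustrative shapes: batch=2, input_size=3, num_units=4.
x = np.random.randn(2, 3)
c0 = np.zeros((2, 4))
m0 = np.zeros((2, 4))
w = 0.1 * np.random.randn(3 + 4, 4 * 4)
b = np.zeros(4 * 4)
m1, c1 = lstm_step(x, c0, m0, w, b)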