Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py')
-rw-r--r--  tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py  201
1 file changed, 71 insertions, 130 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index c2843edaf2..2d65d956a8 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import collections
import math
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -62,7 +61,7 @@ class BasicRNNCell(RNNCell):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
with vs.variable_scope(scope or "basic_rnn_cell"):
output = self._activation(
- _linear([inputs, state], self._num_units, True))
+ _linear([inputs, state], self._num_units, True, scope=scope))
return output, output
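For reference while reading this hunk, the BasicRNNCell math itself is unchanged: output = new_state = act(W * input + U * state + B). A minimal NumPy sketch, with W, U and b standing in for the variables that _linear creates (names here are illustrative):

    import numpy as np

    def basic_rnn_step(x, h_prev, W, U, b, act=np.tanh):
        # output and new state are the same tensor for a basic RNN cell
        h = act(x @ W + h_prev @ U + b)
        return h, h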
@@ -90,13 +89,14 @@ class GRUCell(RNNCell):
# We start with bias of 1.0 to not reset and not update.
r, u = array_ops.split(
value=_linear(
- [inputs, state], 2 * self._num_units, True, 1.0),
+ [inputs, state], 2 * self._num_units, True, 1.0, scope=scope),
num_or_size_splits=2,
axis=1)
r, u = sigmoid(r), sigmoid(u)
with vs.variable_scope("candidate"):
c = self._activation(_linear([inputs, r * state],
- self._num_units, True))
+ self._num_units, True,
+ scope=scope))
new_h = u * state + (1 - u) * c
return new_h, new_h
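The GRU update this hunk threads the scope through, as a minimal NumPy sketch; W_ru, b_ru, W_c and b_c are illustrative stand-ins for the variables _linear creates, with b_ru starting at 1.0 to match bias_start=1.0 above (so the cell initially neither resets nor updates):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def gru_step(x, h_prev, W_ru, b_ru, W_c, b_c, act=np.tanh):
        # r = reset gate, u = update gate
        r, u = np.split(sigmoid(np.concatenate([x, h_prev], 1) @ W_ru + b_ru), 2, axis=1)
        c = act(np.concatenate([x, r * h_prev], 1) @ W_c + b_c)  # candidate state
        new_h = u * h_prev + (1 - u) * c
        return new_h, new_h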
@@ -176,7 +176,7 @@ class BasicLSTMCell(RNNCell):
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
- concat = _linear([inputs, h], 4 * self._num_units, True)
+ concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
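For context around this one-line change, the BasicLSTMCell step as a minimal NumPy sketch; W and b are illustrative stand-ins for the single weight/bias pair _linear creates:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def basic_lstm_step(x, c_prev, h_prev, W, b, forget_bias=1.0, act=np.tanh):
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = np.split(np.concatenate([x, h_prev], 1) @ W + b, 4, axis=1)
        c = sigmoid(f + forget_bias) * c_prev + sigmoid(i) * act(j)
        h = sigmoid(o) * act(c)
        return h, (c, h)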
@@ -192,13 +192,6 @@ class BasicLSTMCell(RNNCell):
return new_h, new_state
-def _maybe_compile(fun, compiled):
- if not compiled:
- return fun
- else:
- return function.Defun(noinline=True, compiled=True)(fun)
-
-
class LSTMCell(RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -226,7 +219,7 @@ class LSTMCell(RNNCell):
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=None, num_proj_shards=None,
forget_bias=1.0, state_is_tuple=True,
- activation=tanh, compiled=False):
+ activation=tanh):
"""Initialize the parameters for an LSTM cell.
Args:
@@ -253,12 +246,6 @@ class LSTMCell(RNNCell):
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.
activation: Activation function of the inner states.
- compiled: Python boolean. If `True`, the core computation of the LSTM
- cell is compiled via XLA. As of now, this provides speedups for
- most GPU calculations, and on small batch CPU and embedded calculations.
-
- Raises:
- ValueError: if compiled=True and state_is_tuple=False (not supported).
"""
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
@@ -270,9 +257,6 @@ class LSTMCell(RNNCell):
"%s: The num_unit_shards and proj_unit_shards parameters are "
"deprecated and will be removed in Jan 2017. "
"Use a variable scope with a partitioner instead.", self)
- if not state_is_tuple and compiled:
- raise ValueError(
- "Combining state_is_tuple=False and compiled=True is not supported.")
self._num_units = num_units
self._use_peepholes = use_peepholes
@@ -285,7 +269,6 @@ class LSTMCell(RNNCell):
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation
- self._compiled = compiled
if num_proj:
self._state_size = (
@@ -334,111 +317,73 @@ class LSTMCell(RNNCell):
"""
num_proj = self._num_units if self._num_proj is None else self._num_proj
- def _kernel(k_inputs, state_p0, state_p1):
- """Internal kernel for the single step of LSTM.
-
- Args:
- k_inputs: Input Tensor.
- state_p0: Either the state or the c component of the state.
- state_p1: Either the state or the m component of the state.
-
- Returns:
- (m, c) or (m, concat([c, m])) depending on state_is_tuple.
-
- Raises:
- ValueError: see above docstring.
- """
- k_inputs.set_shape(inputs.get_shape())
- if self._state_is_tuple:
- (c_prev, m_prev) = state_p0, state_p1
- c_prev.set_shape(state[0].get_shape())
- m_prev.set_shape(state[1].get_shape())
+ if self._state_is_tuple:
+ (c_prev, m_prev) = state
+ else:
+ c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+ dtype = inputs.dtype
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+ with vs.variable_scope(scope or "lstm_cell",
+ initializer=self._initializer) as unit_scope:
+ if self._num_unit_shards is not None:
+ unit_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_unit_shards))
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
+ scope=scope)
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
+
+ # Diagonal connections
+ if self._use_peepholes:
+ with vs.variable_scope(unit_scope) as projection_scope:
+ if self._num_unit_shards is not None:
+ projection_scope.set_partitioner(None)
+ w_f_diag = vs.get_variable(
+ "w_f_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = vs.get_variable(
+ "w_i_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = vs.get_variable(
+ "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * self._activation(j))
else:
- k_state = state_p0
- c_prev = array_ops.slice(k_state, [0, 0], [-1, self._num_units])
- m_prev = array_ops.slice(k_state, [0, self._num_units], [-1, num_proj])
-
- dtype = k_inputs.dtype
- input_size = k_inputs.get_shape().with_rank(2)[1]
- if input_size.value is None:
- raise ValueError(
- "Could not infer input size from inputs.get_shape()[-1]")
- with vs.variable_scope(scope or "lstm_cell",
- initializer=self._initializer) as unit_scope:
- if self._num_unit_shards is not None:
- unit_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_unit_shards))
- # i = input_gate, j = new_input, f = forget_gate, o = output_gate
- lstm_matrix = _linear(
- [k_inputs, m_prev], 4 * self._num_units, bias=True,
- compiled=self._compiled)
- i, j, f, o = array_ops.split(
- value=lstm_matrix, num_or_size_splits=4, axis=1)
-
- # Diagonal connections
- if self._use_peepholes:
- with vs.variable_scope(unit_scope) as projection_scope:
- if self._num_unit_shards is not None:
- projection_scope.set_partitioner(None)
- w_f_diag = vs.get_variable(
- "w_f_diag", shape=[self._num_units], dtype=dtype)
- w_i_diag = vs.get_variable(
- "w_i_diag", shape=[self._num_units], dtype=dtype)
- w_o_diag = vs.get_variable(
- "w_o_diag", shape=[self._num_units], dtype=dtype)
- c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
- sigmoid(i + w_i_diag * c_prev) * self._activation(j))
- else:
- c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
- self._activation(j))
-
- if self._cell_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
- # pylint: enable=invalid-unary-operand-type
-
- if self._use_peepholes:
- m = sigmoid(o + w_o_diag * c) * self._activation(c)
- else:
- m = sigmoid(o) * self._activation(c)
+ c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+ self._activation(j))
- if self._num_proj is not None:
- with vs.variable_scope("projection") as proj_scope:
- if self._num_proj_shards is not None:
- proj_scope.set_partitioner(
- partitioned_variables.fixed_size_partitioner(
- self._num_proj_shards))
- m = _linear(m, self._num_proj, bias=False, compiled=self._compiled)
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
- if self._proj_clip is not None:
- # pylint: disable=invalid-unary-operand-type
- m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
- # pylint: enable=invalid-unary-operand-type
-
- if self._state_is_tuple:
- return m, c
+ if self._use_peepholes:
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
else:
- return m, array_ops.concat([c, m], 1)
+ m = sigmoid(o) * self._activation(c)
- compiled_kernel = _maybe_compile(_kernel, self._compiled)
+ if self._num_proj is not None:
+ with vs.variable_scope("projection") as proj_scope:
+ if self._num_proj_shards is not None:
+ proj_scope.set_partitioner(
+ partitioned_variables.fixed_size_partitioner(
+ self._num_proj_shards))
+ m = _linear(m, self._num_proj, bias=False, scope=scope)
- if self._state_is_tuple:
- batch_shape = (
- inputs.get_shape()[:1].merge_with(
- state[0].get_shape()[:1]).merge_with(
- state[1].get_shape()[:1]))
- emit_m, emit_c = compiled_kernel(inputs, state[0], state[1])
- emit_c.set_shape(batch_shape.concatenate([state[0].get_shape()[1]]))
- emit_m.set_shape(batch_shape.concatenate([state[1].get_shape()[1]]))
- emit_state = LSTMStateTuple(emit_c, emit_m)
- else:
- batch_shape = inputs.get_shape()[:1].merge_with(state.get_shape()[:1])
- emit_m, emit_state = compiled_kernel(inputs, state, state)
- emit_m.set_shape(batch_shape.concatenate([num_proj]))
- emit_state.set_shape(batch_shape.concatenate([state.get_shape()[1]]))
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
- return emit_m, emit_state
+ new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
+ array_ops.concat([c, m], 1))
+ return m, new_state
class OutputProjectionWrapper(RNNCell):
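Putting the rewritten LSTMCell step above back together, a minimal NumPy sketch of the peephole path (dense, unsharded weights; cell and projection clipping omitted; all parameter names here are illustrative):

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    def lstm_step(x, c_prev, m_prev, W, b, w_i, w_f, w_o,
                  W_proj=None, forget_bias=1.0, act=np.tanh):
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = np.split(np.concatenate([x, m_prev], 1) @ W + b, 4, axis=1)
        # peephole ("diagonal") connections let the gates see the cell state directly
        c = (sigmoid(f + forget_bias + w_f * c_prev) * c_prev +
             sigmoid(i + w_i * c_prev) * act(j))
        m = sigmoid(o + w_o * c) * act(c)
        if W_proj is not None:
            m = m @ W_proj  # optional recurrent projection of the output
        return m, (c, m)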
@@ -481,7 +426,7 @@ class OutputProjectionWrapper(RNNCell):
output, res_state = self._cell(inputs, state)
# Default scope: "OutputProjectionWrapper"
with vs.variable_scope(scope or "output_projection_wrapper"):
- projected = _linear(output, self._output_size, True)
+ projected = _linear(output, self._output_size, True, scope=scope)
return projected, res_state
@@ -523,7 +468,7 @@ class InputProjectionWrapper(RNNCell):
"""Run the input projection and then the cell."""
# Default scope: "InputProjectionWrapper"
with vs.variable_scope(scope or "input_projection_wrapper"):
- projected = _linear(inputs, self._num_proj, True)
+ projected = _linear(inputs, self._num_proj, True, scope=scope)
return self._cell(projected, state)
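Neither wrapper's behaviour changes here, only the scope plumbing into _linear. A hedged usage sketch of how these wrappers are typically composed (sizes are illustrative):

    # classes from this module; project inputs to the cell size on the way in,
    # then project the 256-dim cell outputs down to 64 dims on the way out
    cell = OutputProjectionWrapper(
        InputProjectionWrapper(LSTMCell(256), num_proj=256), output_size=64)
    outputs, state = cell(inputs, state)  # same call signature as any RNNCell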
@@ -817,7 +762,7 @@ class _SlimRNNCell(RNNCell):
return output, state
-def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
+def _linear(args, output_size, bias, bias_start=0.0, scope=None):
"""Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
Args:
@@ -825,7 +770,7 @@ def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
output_size: int, second dimension of W[i].
bias: boolean, whether to add a bias term or not.
bias_start: starting value to initialize the bias; 0 by default.
- compiled: boolean, _linear plays nicely with XLA if it is enabled.
+ scope: (optional) Variable scope to create parameters in.
Returns:
A 2D Tensor with shape [batch x output_size] equal to
@@ -870,8 +815,4 @@ def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
"biases", [output_size],
dtype=dtype,
initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
- if compiled:
- # TODO(b/34505635): Defuns don't play well with bias_add
- return res + biases
- else:
- return nn_ops.bias_add(res, biases)
+ return nn_ops.bias_add(res, biases)
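What _linear computes, stripped of the variable-scope, partitioning and shape-checking machinery (a minimal NumPy sketch; W and b stand in for the "weights" and "biases" variables it creates):

    import numpy as np

    def linear(args, W, b=None):
        # sum_i(args[i] * W[i]) is computed as one matmul on the concatenated inputs
        x = np.concatenate(args, axis=1) if isinstance(args, (list, tuple)) else args
        res = x @ W
        return res if b is None else res + b  # nn_ops.bias_add in the real code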