author    James Qin <jamesqin@google.com>  2017-11-13 12:59:04 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-13 13:03:41 -0800
commit    90222dd7b29ff2597bc7f8d0f92db17324f591b0 (patch)
tree      0e9b5f55057b908824d2c3ce8c8e8e176702a2d9 /tensorflow/contrib/cudnn_rnn
parent    333bdea9524dd2bacf626051dbdbbcfcc4b46122 (diff)
Fix CuDNNCompatibleGRU after the GRUCell refactoring
PiperOrigin-RevId: 175574730
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--  tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 93
1 file changed, 65 insertions(+), 28 deletions(-)
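In short: the GRUCell refactor removed the private core_rnn_cell._linear helper this cell depended on, so the cell now creates its gate and candidate variables explicitly in build() and applies them with plain matmul/bias_add in call(), preserving the cuDNN-compatible variable layout (separate input and hidden candidate projections, each with its own bias). For reference, the recurrence the cell implements (per the docstring in the diff below) can be sketched in NumPy. This is an illustrative transcription, not code from the patch; the dict-of-weights layout and names are a notational convenience:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cudnn_gru_step(x, h_prev, W, R, b_W, b_R):
    """One cuDNN-style GRU step.

    W, R, b_W, b_R are dicts keyed by 'r', 'u', 'h' holding the input
    weights, recurrent weights, input biases and recurrent biases.
    """
    r = sigmoid(x @ W['r'] + h_prev @ R['r'] + b_W['r'] + b_R['r'])  # reset gate
    u = sigmoid(x @ W['u'] + h_prev @ R['u'] + b_W['u'] + b_R['u'])  # update gate
    # Unlike the canonical GRU, the reset gate is applied *after* the
    # recurrent projection and its bias (h_prev @ R['h'] + b_R['h']).
    h_new = np.tanh(x @ W['h'] + r * (h_prev @ R['h'] + b_R['h']) + b_W['h'])
    return (1.0 - u) * h_new + u * h_prev

# Tiny smoke test with random weights (batch of 2, 4 inputs, 3 units).
rng = np.random.default_rng(0)
W = {k: rng.standard_normal((4, 3)) for k in 'ruh'}
R = {k: rng.standard_normal((3, 3)) for k in 'ruh'}
b_W = {k: np.zeros(3) for k in 'ruh'}
b_R = {k: np.zeros(3) for k in 'ruh'}
h = cudnn_gru_step(rng.standard_normal((2, 4)), np.zeros((2, 3)), W, R, b_W, b_R)
```

The placement of the reset gate after the recurrent bias is why the candidate keeps two separate bias vectors, and hence why the patch creates distinct input_projection and hidden_projection variables below.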
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 9f74899693..6c526b2c75 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -18,7 +18,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops
-from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.contrib.util import loader
from tensorflow.python.framework import common_shapes
@@ -29,6 +28,7 @@ from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
@@ -55,6 +55,11 @@ CUDNN_INPUT_LINEAR_MODE = "linear_input"
CUDNN_INPUT_SKIP_MODE = "skip_input"
CUDNN_INPUT_AUTO_MODE = "auto_select"
+# pylint:disable=protected-access
+_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME
+_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME
+# pylint:enable=protected-access
+
class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
"""Cudnn Compatible LSTMCell.
@@ -87,9 +92,9 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
  Cudnn compatible GRU (from Cudnn library user guide):
  ```python
  r_t = sigma(x_t * W_r + h_t-1 * R_r + b_Wr + b_Rr)  # reset gate
-  i_t = sigma(x_t * W_i + h_t-1 * R_i + b_Wi + b_Ru)  # update gate
+  u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru)  # update gate
  h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh)  # new memory gate
-  h_t = (1 - i_t) .* h'_t + i_t .* h_t-1
+  h_t = (1 - u_t) .* h'_t + u_t .* h_t-1
  ```

  Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}):
@@ -112,33 +117,65 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
        reuse=reuse,
        kernel_initializer=kernel_initializer)
+  def build(self, inputs_shape):
+    if inputs_shape[1].value is None:
+      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+                       % inputs_shape)
+
+    input_depth = inputs_shape[1].value
+    self._gate_kernel = self.add_variable(
+        "gates/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth + self._num_units, 2 * self._num_units],
+        initializer=self._kernel_initializer)
+    self._gate_bias = self.add_variable(
+        "gates/%s" % _BIAS_VARIABLE_NAME,
+        shape=[2 * self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.constant_initializer(1.0, dtype=self.dtype)))
+
+    self._candidate_input_kernel = self.add_variable(
+        "candidate/input_projection/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[input_depth, self._num_units],
+        initializer=self._kernel_initializer)
+    self._candidate_hidden_kernel = self.add_variable(
+        "candidate/hidden_projection/%s" % _WEIGHTS_VARIABLE_NAME,
+        shape=[self._num_units, self._num_units],
+        initializer=self._kernel_initializer)
+
+    self._candidate_input_bias = self.add_variable(
+        "candidate/input_projection/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+    self._candidate_hidden_bias = self.add_variable(
+        "candidate/hidden_projection/%s" % _BIAS_VARIABLE_NAME,
+        shape=[self._num_units],
+        initializer=(
+            self._bias_initializer
+            if self._bias_initializer is not None
+            else init_ops.zeros_initializer(dtype=self.dtype)))
+
  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
-    with vs.variable_scope("gates"):  # Reset gate and update gate.
-      # We start with bias of 1.0 to not reset and not update.
-      bias_ones = self._bias_initializer
-      if self._bias_initializer is None:
-        dtype = inputs.dtype
-        bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
-      # pylint: disable=protected-access
-      value = math_ops.sigmoid(
-          core_rnn_cell._linear([inputs, state], 2 * self._num_units, True,
-                                bias_ones, self._kernel_initializer))
-      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
-      # pylint: enable=protected-access
-    with vs.variable_scope("candidate"):
-      # pylint: disable=protected-access
-      with vs.variable_scope("input_projection"):
-        hi = core_rnn_cell._linear(inputs, self._num_units, True,
-                                   self._bias_initializer,
-                                   self._kernel_initializer)
-      with vs.variable_scope("hidden_projection"):
-        hh = r * (core_rnn_cell._linear(state, self._num_units, True,
-                                        self._bias_initializer,
-                                        self._kernel_initializer))
-      # pylint: enable=protected-access
-      c = self._activation(hi + hh)
-    new_h = u * state + (1 - u) * c
+    gate_inputs = math_ops.matmul(
+        array_ops.concat([inputs, state], 1), self._gate_kernel)
+    gate_inputs = nn_ops.bias_add(gate_inputs, self._gate_bias)
+
+    value = math_ops.sigmoid(gate_inputs)
+    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
+
+    candidate = nn_ops.bias_add(
+        math_ops.matmul(inputs, self._candidate_input_kernel),
+        self._candidate_input_bias)
+    candidate += r * nn_ops.bias_add(
+        math_ops.matmul(state, self._candidate_hidden_kernel),
+        self._candidate_hidden_bias)
+    candidate = self._activation(candidate)
+    new_h = (1 - u) * candidate + u * state
    return new_h, new_h
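For completeness, a minimal usage sketch under the TF 1.x graph-mode API of this era. The shapes and the dynamic_rnn wiring are illustrative assumptions, not part of the patch; only the constructor (num_units, per the __init__ context above) and the variable names come from the diff:

```python
import tensorflow as tf
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

# Illustrative shapes; any [batch, time, depth] input works.
inputs = tf.placeholder(tf.float32, [32, 10, 128])

# Per the class docstring, the cell mirrors CudnnGRU's math and variable
# layout, so it is intended for running CudnnGRU-trained weights on CPU.
cell = cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units=256)

# dynamic_rnn calls cell.build() on first use, creating the gates/ and
# candidate/ variables defined in the patch above.
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
```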