 tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py |  2 +-
 tensorflow/contrib/rnn/python/ops/rnn_cell.py               | 56 +++++++++++-----------
 2 files changed, 31 insertions(+), 27 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index aed16b4dba..c840ce5227 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -859,7 +859,7 @@ class RNNCellTest(test.TestCase):
[1.844123e-05, -2.159617e-05]],
dtype=np.float32)
with variable_scope.variable_scope("root"):
- t = array_ops.zeros([batch_size, 1])
+ t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)
x = array_ops.zeros([batch_size, input_size])
c0 = array_ops.zeros([batch_size, 2])
h0 = array_ops.zeros([batch_size, 2])
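The test hunk above only changes the dtype of the time tensor. For orientation, here is a minimal sketch of how such a cell is driven with a (time, features) tuple; the state-tuple construction and the module aliases (rnn_cell, core_rnn_cell) are assumed from the surrounding test file and this module's docstring, not shown in this hunk:

    # Illustrative sketch only; aliases are assumed, values are placeholders.
    batch_size, input_size = 4, 3
    t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)  # time, now float64
    x = array_ops.zeros([batch_size, input_size])                # features, float32
    c0 = array_ops.zeros([batch_size, 2])
    h0 = array_ops.zeros([batch_size, 2])
    state0 = core_rnn_cell.LSTMStateTuple(c0, h0)
    cell = rnn_cell.PhasedLSTMCell(num_units=2)
    output, state = cell((t, x), state0)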
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 1632946ea6..2cd1814213 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1727,8 +1727,8 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
num_units,
use_peepholes=False,
leak=0.001,
- r_on=0.1,
- trainable_r_on=True,
+ ratio_on=0.1,
+ trainable_ratio_on=True,
period_init_min=1.0,
period_init_max=1000.0,
reuse=None):
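The constructor hunk above only renames keyword arguments. A hedged usage sketch with the new names (the rnn_cell module alias is assumed):

    # Illustrative: construct the cell with the renamed arguments.
    cell = rnn_cell.PhasedLSTMCell(
        num_units=2,
        use_peepholes=True,
        leak=0.001,
        ratio_on=0.05,             # formerly r_on
        trainable_ratio_on=False,  # formerly trainable_r_on
        period_init_min=1.0,
        period_init_max=1000.0)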
@@ -1739,9 +1739,9 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
use_peepholes: bool, set True to enable peephole connections.
leak: float or scalar float Tensor with value in [0, 1]. Leak applied
during training.
- r_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
+ ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
period during which the gates are open.
- trainable_r_on: bool, weather r_on is trainable.
+      trainable_ratio_on: bool, whether ratio_on is trainable.
period_init_min: float or scalar float Tensor. With value > 0.
        Minimum value of the initialized period.
The period values are initialized by drawing from the distribution:
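The docstring above refers to the period initialization distribution. A rough numpy sketch of a log-uniform draw of that shape is shown below; it is assumed to mirror what _random_exp_initializer does and is illustrative only, not the TensorFlow implementation:

    import numpy as np

    def sample_periods(num_units, period_init_min=1.0, period_init_max=1000.0):
        # Sample uniformly in log-space and exponentiate, so the periods
        # spread across several orders of magnitude between min and max.
        log_samples = np.random.uniform(np.log(period_init_min),
                                        np.log(period_init_max),
                                        size=num_units)
        return np.exp(log_samples)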
@@ -1756,8 +1756,8 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
self._num_units = num_units
self._use_peepholes = use_peepholes
self._leak = leak
- self._r_on = r_on
- self._trainable_r_on = trainable_r_on
+ self._ratio_on = ratio_on
+ self._trainable_ratio_on = trainable_ratio_on
self._period_init_min = period_init_min
self._period_init_max = period_init_max
self._reuse = reuse
@@ -1774,14 +1774,23 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
"""Modulo function that propagates x gradients."""
return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
+ def _get_cycle_ratio(self, time, phase, period):
+    """Compute the cycle ratio in the dtype of the time tensor."""
+ phase_casted = math_ops.cast(phase, dtype=time.dtype)
+ period_casted = math_ops.cast(period, dtype=time.dtype)
+ shifted_time = time - phase_casted
+ cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
+ return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
+
def __call__(self, inputs, state, scope=None):
"""Phased LSTM Cell.
Args:
- inputs: A tuple of 2 Tensor of type float32.
- The first Tensor has shape [batch, 1], and stores the time.
- The second Tesnsor has shape [batch, features_size], and stores the
- features.
+      inputs: A tuple of 2 Tensors.
+ The first Tensor has shape [batch, 1], and type float32 or float64.
+ It stores the time.
+ The second Tensor has shape [batch, features_size], and type float32.
+ It stores the features.
state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
scope: string, id of the variable scope.
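The new _get_cycle_ratio helper is the core of the dtype fix: the learnable phase and period are cast up to the time tensor's dtype, the ratio is computed there, and the result is cast back to float32 for the rest of the graph. A standalone numpy sketch of the same arithmetic (illustrative, not the graph code):

    import numpy as np

    def get_cycle_ratio(time, phase, period):
        time = np.asarray(time, dtype=np.float64)     # time may be float64
        phase = np.asarray(phase, dtype=time.dtype)   # cast learnable phase up
        period = np.asarray(period, dtype=time.dtype)
        shifted_time = time - phase
        cycle_ratio = np.mod(shifted_time, period) / period
        return cycle_ratio.astype(np.float32)         # back to float32 for the gates

    # e.g. time 130.0, phase 5.0, period 50.0 -> (125 mod 50) / 50 = 0.5
    print(get_cycle_ratio(130.0, 5.0, 50.0))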
@@ -1795,7 +1804,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
(c_prev, h_prev) = state
(time, x) = inputs
- dtype = x.dtype
in_mask_gates = [x, h_prev]
if self._use_peepholes:
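Dropping dtype = x.dtype here leans on get_variable's default dtype, so the gate variables created below stay float32 regardless of the time input. A tiny TF 1.x-style sketch of that assumption (illustrative, not part of the change):

    import tensorflow as tf

    # Under TF 1.x, get_variable defaults to float32 when no dtype is given,
    # so variables such as "period", "phase" and "ratio_on" remain float32
    # even though the time input may be float64.
    v = tf.get_variable("example_var", shape=[3])
    print(v.dtype)  # float32 reference variable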
@@ -1826,28 +1834,24 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
period = vs.get_variable(
"period", [self._num_units],
initializer=_random_exp_initializer(
- self._period_init_min, self._period_init_max),
- dtype=dtype)
+ self._period_init_min, self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
initializer=init_ops.random_uniform_initializer(
- 0., period.initial_value),
- dtype=dtype)
- r_on = vs.get_variable(
- "r_on", [self._num_units],
- initializer=init_ops.constant_initializer(self._r_on),
- trainable=self._trainable_r_on,
- dtype=dtype)
+ 0., period.initial_value))
+ ratio_on = vs.get_variable(
+ "ratio_on", [self._num_units],
+ initializer=init_ops.constant_initializer(self._ratio_on),
+ trainable=self._trainable_ratio_on)
- shifted_time = time - phase
- ph_time = self._mod(shifted_time, period) / period
+ cycle_ratio = self._get_cycle_ratio(time, phase, period)
- k_up = 2 * ph_time / r_on
+ k_up = 2 * cycle_ratio / ratio_on
k_down = 2 - k_up
- k_closed = self._leak * ph_time
+ k_closed = self._leak * cycle_ratio
- k = array_ops.where(ph_time < self._r_on, k_down, k_closed)
- k = array_ops.where(ph_time < 0.5 * self._r_on, k_up, k)
+ k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
+ k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
new_c = k * new_c + (1 - k) * c_prev
new_h = k * new_h + (1 - k) * h_prev
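Finally, a numpy sketch of the piecewise time gate assembled above (illustrative only): the gate ramps up over the first half of the open fraction of the cycle, ramps down over the second half, and applies only a small leak while closed; new_c and new_h then blend the candidate and previous states through k as in the last two lines of the hunk.

    import numpy as np

    def time_gate(cycle_ratio, ratio_on=0.1, leak=0.001):
        k_up = 2.0 * cycle_ratio / ratio_on
        k_down = 2.0 - k_up
        k_closed = leak * cycle_ratio
        k = np.where(cycle_ratio < ratio_on, k_down, k_closed)
        k = np.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
        return k

    # 0.025 -> rising edge (0.5), 0.075 -> falling edge (0.5),
    # 0.5 -> closed (leak * 0.5 = 0.0005)
    print(time_gate(np.array([0.025, 0.075, 0.5])))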