diff options
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 56 |
2 files changed, 31 insertions, 27 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index aed16b4dba..c840ce5227 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -859,7 +859,7 @@ class RNNCellTest(test.TestCase): [1.844123e-05, -2.159617e-05]], dtype=np.float32) with variable_scope.variable_scope("root"): - t = array_ops.zeros([batch_size, 1]) + t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64) x = array_ops.zeros([batch_size, input_size]) c0 = array_ops.zeros([batch_size, 2]) h0 = array_ops.zeros([batch_size, 2]) diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 1632946ea6..2cd1814213 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1727,8 +1727,8 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): num_units, use_peepholes=False, leak=0.001, - r_on=0.1, - trainable_r_on=True, + ratio_on=0.1, + trainable_ratio_on=True, period_init_min=1.0, period_init_max=1000.0, reuse=None): @@ -1739,9 +1739,9 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): use_peepholes: bool, set True to enable peephole connections. leak: float or scalar float Tensor with value in [0, 1]. Leak applied during training. - r_on: float or scalar float Tensor with value in [0, 1]. Ratio of the + ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the period during which the gates are open. - trainable_r_on: bool, weather r_on is trainable. + trainable_ratio_on: bool, weather ratio_on is trainable. period_init_min: float or scalar float Tensor. With value > 0. Minimum value of the initalized period. The period values are initialized by drawing from the distribution: @@ -1756,8 +1756,8 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): self._num_units = num_units self._use_peepholes = use_peepholes self._leak = leak - self._r_on = r_on - self._trainable_r_on = trainable_r_on + self._ratio_on = ratio_on + self._trainable_ratio_on = trainable_ratio_on self._period_init_min = period_init_min self._period_init_max = period_init_max self._reuse = reuse @@ -1774,14 +1774,23 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): """Modulo function that propagates x gradients.""" return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x + def _get_cycle_ratio(self, time, phase, period): + """Compute the cycle ratio in the dtype of the time.""" + phase_casted = math_ops.cast(phase, dtype=time.dtype) + period_casted = math_ops.cast(period, dtype=time.dtype) + shifted_time = time - phase_casted + cycle_ratio = self._mod(shifted_time, period_casted) / period_casted + return math_ops.cast(cycle_ratio, dtype=dtypes.float32) + def __call__(self, inputs, state, scope=None): """Phased LSTM Cell. Args: - inputs: A tuple of 2 Tensor of type float32. - The first Tensor has shape [batch, 1], and stores the time. - The second Tesnsor has shape [batch, features_size], and stores the - features. + inputs: A tuple of 2 Tensor. + The first Tensor has shape [batch, 1], and type float32 or float64. + It stores the time. + The second Tensor has shape [batch, features_size], and type float32. + It stores the features. state: core_rnn_cell.LSTMStateTuple, state from previous timestep. scope: string, id of the variable scope. @@ -1795,7 +1804,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse): (c_prev, h_prev) = state (time, x) = inputs - dtype = x.dtype in_mask_gates = [x, h_prev] if self._use_peepholes: @@ -1826,28 +1834,24 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell): period = vs.get_variable( "period", [self._num_units], initializer=_random_exp_initializer( - self._period_init_min, self._period_init_max), - dtype=dtype) + self._period_init_min, self._period_init_max)) phase = vs.get_variable( "phase", [self._num_units], initializer=init_ops.random_uniform_initializer( - 0., period.initial_value), - dtype=dtype) - r_on = vs.get_variable( - "r_on", [self._num_units], - initializer=init_ops.constant_initializer(self._r_on), - trainable=self._trainable_r_on, - dtype=dtype) + 0., period.initial_value)) + ratio_on = vs.get_variable( + "ratio_on", [self._num_units], + initializer=init_ops.constant_initializer(self._ratio_on), + trainable=self._trainable_ratio_on) - shifted_time = time - phase - ph_time = self._mod(shifted_time, period) / period + cycle_ratio = self._get_cycle_ratio(time, phase, period) - k_up = 2 * ph_time / r_on + k_up = 2 * cycle_ratio / ratio_on k_down = 2 - k_up - k_closed = self._leak * ph_time + k_closed = self._leak * cycle_ratio - k = array_ops.where(ph_time < self._r_on, k_down, k_closed) - k = array_ops.where(ph_time < 0.5 * self._r_on, k_up, k) + k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed) + k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k) new_c = k * new_c + (1 - k) * c_prev new_h = k * new_h + (1 - k) * h_prev |