diff options
-rw-r--r-- | tensorflow/python/keras/backend.py | 14 | ||||
-rw-r--r-- | tensorflow/python/keras/backend_test.py | 111 |
2 files changed, 120 insertions, 5 deletions
# Reviewer notes (commentary placed ahead of the first "diff --git" header).
#
# backend.py, rnn(): in the masked step, each state is now selected with a
# mask tiled to that state's own second dimension
# (array_ops.shape(new_states[i])[1]) instead of reusing tiled_mask_t, which
# is tiled to array_ops.shape(output)[1].
#
# backend_test.py: adds test_rnn_additional_states. Its step function returns
# a second state built by concatenating the output with itself, and the test
# asserts static shapes [num_samples, output_dim] and
# [num_samples, 2 * output_dim] for the two returned states.
#
# NOTE(review): line breaks and indentation below were reconstructed from a
# rendering that collapsed all whitespace; the exact leading whitespace could
# not be recovered from it -- verify against the original commit before
# applying.
# NOTE(review): the backend.py hunk header declares 16 new-side lines, but the
# rendering carries two bare "+" markers after "new_states = masked_states",
# which makes 17; one is probably a rendering artifact -- confirm.
# NOTE(review): in the new test each per-configuration list receives exactly
# one element, so assert_list_pairwise() iterates over an empty zip, and the
# results collected for the two masked configurations (indices 4 and 5) are
# not referenced by any assertion.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index fed779650e..11f99c030f 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3161,10 +3161,16 @@ def rnn(step_function,
                                       array_ops.stack(
                                           [1, array_ops.shape(output)[1]]))
         output = array_ops.where(tiled_mask_t, output, states[0])
-        new_states = [
-            array_ops.where(tiled_mask_t, new_states[i], states[i])
-            for i in range(len(states))
-        ]
+
+        masked_states = []
+        for i in range(len(states)):
+          states_dim = array_ops.shape(new_states[i])[1]
+          stacked_states_dim = array_ops.stack([1, states_dim])
+          tiled_mask = array_ops.tile(mask_t, stacked_states_dim)
+          masked_state = array_ops.where(tiled_mask, new_states[i], states[i])
+          masked_states.append(masked_state)
+        new_states = masked_states
+
+
         output_ta_t = output_ta_t.write(time, output)
         return (time + 1, output_ta_t) + tuple(new_states)
     else:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 2ba6c8ef15..0ddffa61a4 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1077,7 +1077,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
         {'go_backwards': False, 'mask': mask, 'unroll': True},
     ]
     with self.test_session():
-      for (i, kwargs) in enumerate(kwargs_list):
+      for i, kwargs in enumerate(kwargs_list):
         last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
                                                              initial_states,
                                                              **kwargs)
@@ -1124,6 +1124,115 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
       for b_s, b_u_s in zip(state_list[2], state_list[3]):
         self.assertAllClose(b_s, b_u_s, atol=1e-04)
 
+  def test_rnn_additional_states(self):
+    # implement a simple RNN
+    num_samples = 4
+    input_dim = 5
+    output_dim = 3
+    timesteps = 6
+
+    input_val = np.random.random(
+        (num_samples, timesteps, input_dim)).astype(np.float32)
+    init_state_val = np.random.random(
+        (num_samples, output_dim)).astype(np.float32)
+    w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
+    w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
+    np_mask = np.random.randint(2, size=(num_samples, timesteps))
+
+    def rnn_step_fn():
+      w_i = keras.backend.variable(w_i_val)
+      w_o = keras.backend.variable(w_o_val)
+
+      def step_function(x, states):
+        assert len(states) == 2
+        prev_output = states[0]
+        output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
+        return output, [output,
+                        keras.backend.concatenate([output, output], axis=-1)]
+
+      return step_function
+
+    # test default setup
+    last_output_list = [[], [], [], [], [], []]
+    outputs_list = [[], [], [], [], [], []]
+    state_list = [[], [], [], [], [], []]
+    additional_state_list = [[], [], [], [], [], []]
+
+    rnn_fn = rnn_step_fn()
+    inputs = keras.backend.variable(input_val)
+    initial_states = [keras.backend.variable(init_state_val),
+                      np.concatenate([init_state_val, init_state_val], axis=-1)]
+    mask = keras.backend.variable(np_mask)
+
+    kwargs_list = [
+        {'go_backwards': False, 'mask': None},
+        {'go_backwards': False, 'mask': None, 'unroll': True},
+        {'go_backwards': True, 'mask': None},
+        {'go_backwards': True, 'mask': None, 'unroll': True},
+        {'go_backwards': False, 'mask': mask},
+        {'go_backwards': False, 'mask': mask, 'unroll': True},
+    ]
+    with self.test_session():
+      for i, kwargs in enumerate(kwargs_list):
+        last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
+                                                             initial_states,
+                                                             **kwargs)
+        # check static shape inference
+        self.assertEqual(last_output.get_shape().as_list(),
+                         [num_samples, output_dim])
+        self.assertEqual(outputs.get_shape().as_list(),
+                         [num_samples, timesteps, output_dim])
+        # for state in new_states:
+        #   self.assertEquals(state.get_shape().as_list(),
+        #                     [num_samples, output_dim])
+        self.assertEqual(new_states[0].get_shape().as_list(),
+                         [num_samples, output_dim])
+        self.assertEqual(new_states[1].get_shape().as_list(),
+                         [num_samples, 2 * output_dim])
+
+        last_output_list[i].append(keras.backend.eval(last_output))
+        outputs_list[i].append(keras.backend.eval(outputs))
+        self.assertEqual(len(new_states), 2)
+        state_list[i].append(keras.backend.eval(new_states[0]))
+        additional_state_list[i].append(keras.backend.eval(new_states[1]))
+
+      def assert_list_pairwise(z_list, atol=1e-05):
+        for (z1, z2) in zip(z_list[1:], z_list[:-1]):
+          self.assertAllClose(z1, z2, atol=atol)
+
+      assert_list_pairwise(last_output_list[0], atol=1e-04)
+      assert_list_pairwise(outputs_list[0], atol=1e-04)
+      assert_list_pairwise(state_list[0], atol=1e-04)
+      assert_list_pairwise(additional_state_list[0], atol=1e-04)
+      assert_list_pairwise(last_output_list[2], atol=1e-04)
+      assert_list_pairwise(outputs_list[2], atol=1e-04)
+      assert_list_pairwise(state_list[2], atol=1e-04)
+      assert_list_pairwise(additional_state_list[2], atol=1e-04)
+
+      for l, u_l in zip(last_output_list[0], last_output_list[1]):
+        self.assertAllClose(l, u_l, atol=1e-04)
+
+      for o, u_o in zip(outputs_list[0], outputs_list[1]):
+        self.assertAllClose(o, u_o, atol=1e-04)
+
+      for s, u_s in zip(state_list[0], state_list[1]):
+        self.assertAllClose(s, u_s, atol=1e-04)
+
+      for s, u_s in zip(additional_state_list[0], additional_state_list[1]):
+        self.assertAllClose(s, u_s, atol=1e-04)
+
+      for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
+        self.assertAllClose(b_l, b_u_l, atol=1e-04)
+
+      for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
+        self.assertAllClose(b_o, b_u_o, atol=1e-04)
+
+      for b_s, b_u_s in zip(state_list[2], state_list[3]):
+        self.assertAllClose(b_s, b_u_s, atol=1e-04)
+
+      for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
+        self.assertAllClose(s, u_s, atol=1e-04)
+
   def test_normalize_batch_in_training(self):
     val = np.random.random((10, 3, 10, 10))
     x = keras.backend.variable(val)
|