about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/python/keras/backend.py 14
-rw-r--r-- tensorflow/python/keras/backend_test.py 111
2 files changed, 120 insertions, 5 deletions
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index fed779650e..11f99c030f 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -3161,10 +3161,16 @@ def rnn(step_function,
array_ops.stack(
[1, array_ops.shape(output)[1]]))
output = array_ops.where(tiled_mask_t, output, states[0])
- new_states = [
- array_ops.where(tiled_mask_t, new_states[i], states[i])
- for i in range(len(states))
- ]
+
+ masked_states = []
+ for i in range(len(states)):
+ states_dim = array_ops.shape(new_states[i])[1]
+ stacked_states_dim = array_ops.stack([1, states_dim])
+ tiled_mask = array_ops.tile(mask_t, stacked_states_dim)
+ masked_state = array_ops.where(tiled_mask, new_states[i], states[i])
+ masked_states.append(masked_state)
+ new_states = masked_states
+
output_ta_t = output_ta_t.write(time, output)
return (time + 1, output_ta_t) + tuple(new_states)
else:
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index 2ba6c8ef15..0ddffa61a4 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -1077,7 +1077,7 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
{'go_backwards': False, 'mask': mask, 'unroll': True},
]
with self.test_session():
- for (i, kwargs) in enumerate(kwargs_list):
+ for i, kwargs in enumerate(kwargs_list):
last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
initial_states,
**kwargs)
@@ -1124,6 +1124,115 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
for b_s, b_u_s in zip(state_list[2], state_list[3]):
self.assertAllClose(b_s, b_u_s, atol=1e-04)
+ def test_rnn_additional_states(self):
+ # implement a simple RNN
+ num_samples = 4
+ input_dim = 5
+ output_dim = 3
+ timesteps = 6
+
+ input_val = np.random.random(
+ (num_samples, timesteps, input_dim)).astype(np.float32)
+ init_state_val = np.random.random(
+ (num_samples, output_dim)).astype(np.float32)
+ w_i_val = np.random.random((input_dim, output_dim)).astype(np.float32)
+ w_o_val = np.random.random((output_dim, output_dim)).astype(np.float32)
+ np_mask = np.random.randint(2, size=(num_samples, timesteps))
+
+ def rnn_step_fn():
+ w_i = keras.backend.variable(w_i_val)
+ w_o = keras.backend.variable(w_o_val)
+
+ def step_function(x, states):
+ assert len(states) == 2
+ prev_output = states[0]
+ output = keras.backend.dot(x, w_i) + keras.backend.dot(prev_output, w_o)
+ return output, [output,
+ keras.backend.concatenate([output, output], axis=-1)]
+
+ return step_function
+
+ # test default setup
+ last_output_list = [[], [], [], [], [], []]
+ outputs_list = [[], [], [], [], [], []]
+ state_list = [[], [], [], [], [], []]
+ additional_state_list = [[], [], [], [], [], []]
+
+ rnn_fn = rnn_step_fn()
+ inputs = keras.backend.variable(input_val)
+ initial_states = [keras.backend.variable(init_state_val),
+ np.concatenate([init_state_val, init_state_val], axis=-1)]
+ mask = keras.backend.variable(np_mask)
+
+ kwargs_list = [
+ {'go_backwards': False, 'mask': None},
+ {'go_backwards': False, 'mask': None, 'unroll': True},
+ {'go_backwards': True, 'mask': None},
+ {'go_backwards': True, 'mask': None, 'unroll': True},
+ {'go_backwards': False, 'mask': mask},
+ {'go_backwards': False, 'mask': mask, 'unroll': True},
+ ]
+ with self.test_session():
+ for i, kwargs in enumerate(kwargs_list):
+ last_output, outputs, new_states = keras.backend.rnn(rnn_fn, inputs,
+ initial_states,
+ **kwargs)
+ # check static shape inference
+ self.assertEqual(last_output.get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEqual(outputs.get_shape().as_list(),
+ [num_samples, timesteps, output_dim])
+ # for state in new_states:
+ # self.assertEquals(state.get_shape().as_list(),
+ # [num_samples, output_dim])
+ self.assertEqual(new_states[0].get_shape().as_list(),
+ [num_samples, output_dim])
+ self.assertEqual(new_states[1].get_shape().as_list(),
+ [num_samples, 2 * output_dim])
+
+ last_output_list[i].append(keras.backend.eval(last_output))
+ outputs_list[i].append(keras.backend.eval(outputs))
+ self.assertEqual(len(new_states), 2)
+ state_list[i].append(keras.backend.eval(new_states[0]))
+ additional_state_list[i].append(keras.backend.eval(new_states[1]))
+
+ def assert_list_pairwise(z_list, atol=1e-05):
+ for (z1, z2) in zip(z_list[1:], z_list[:-1]):
+ self.assertAllClose(z1, z2, atol=atol)
+
+ assert_list_pairwise(last_output_list[0], atol=1e-04)
+ assert_list_pairwise(outputs_list[0], atol=1e-04)
+ assert_list_pairwise(state_list[0], atol=1e-04)
+ assert_list_pairwise(additional_state_list[0], atol=1e-04)
+ assert_list_pairwise(last_output_list[2], atol=1e-04)
+ assert_list_pairwise(outputs_list[2], atol=1e-04)
+ assert_list_pairwise(state_list[2], atol=1e-04)
+ assert_list_pairwise(additional_state_list[2], atol=1e-04)
+
+ for l, u_l in zip(last_output_list[0], last_output_list[1]):
+ self.assertAllClose(l, u_l, atol=1e-04)
+
+ for o, u_o in zip(outputs_list[0], outputs_list[1]):
+ self.assertAllClose(o, u_o, atol=1e-04)
+
+ for s, u_s in zip(state_list[0], state_list[1]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
+ for s, u_s in zip(additional_state_list[0], additional_state_list[1]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
+ for b_l, b_u_l in zip(last_output_list[2], last_output_list[3]):
+ self.assertAllClose(b_l, b_u_l, atol=1e-04)
+
+ for b_o, b_u_o in zip(outputs_list[2], outputs_list[3]):
+ self.assertAllClose(b_o, b_u_o, atol=1e-04)
+
+ for b_s, b_u_s in zip(state_list[2], state_list[3]):
+ self.assertAllClose(b_s, b_u_s, atol=1e-04)
+
+ for s, u_s in zip(additional_state_list[2], additional_state_list[3]):
+ self.assertAllClose(s, u_s, atol=1e-04)
+
def test_normalize_batch_in_training(self):
val = np.random.random((10, 3, 10, 10))
x = keras.backend.variable(val)