diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-07 14:31:49 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-07 14:39:52 -0700 |
commit | 9b84c91b68915a26aa9d732988cbf13a7626c2dd (patch) | |
tree | 371bf3f94eb07f38a988649d5e72f405d2aa832c /tensorflow/contrib/recurrent | |
parent | cf233f7281fbb841f5ea2df548e6e857308f222b (diff) |
Fix the output shape of functional_rnn for time-major inputs.
PiperOrigin-RevId: 207780606
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r-- | tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py | 158 | ||||
-rw-r--r-- | tensorflow/contrib/recurrent/python/ops/functional_rnn.py | 9 |
2 files changed, 149 insertions, 18 deletions
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py index 0f19ac7dbe..f23194a6f2 100644 --- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py +++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py @@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): func, args = self._CELLDEFS[celldef_name] return func(*args) - def _CreateInputs(self): - inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE, - FunctionalRnnTest._TOTAL_TIME, - FunctionalRnnTest._INPUT_SIZE]) + def _CreateInputs(self, time_major=False): + if time_major: + inputs = np.random.random([ + FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE + ]) + else: + inputs = np.random.random([ + FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME, + FunctionalRnnTest._INPUT_SIZE + ]) # Always leave one time slot empty, to check max_length behavior. sequence_length = np.random.randint( 0, high=FunctionalRnnTest._TOTAL_TIME - 1, @@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): dtype=np.int) return (inputs, sequence_length) - def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs, - tf_sequence_length, initial_state=None, - time_major=None, scope=None): - tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs, - sequence_length=tf_sequence_length, - initial_state=initial_state, - dtype=dtypes.float32, - time_major=time_major, - scope=scope) + def _CreateSymmetricInputs(self): + # total time = batch size + inputs = np.zeros( + (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE, + FunctionalRnnTest._INPUT_SIZE)) + for i in range(FunctionalRnnTest._BATCH_SIZE): + for j in range(i, FunctionalRnnTest._BATCH_SIZE): + inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE]) + inputs[j][i] = inputs[i][j] + + # Always leave one time slot empty, to check max_length behavior. + sequence_length = np.random.randint( + 0, + high=FunctionalRnnTest._BATCH_SIZE - 1, + size=FunctionalRnnTest._BATCH_SIZE, + dtype=np.int) + return (inputs, sequence_length) + + def _CreateRnnGraph(self, + create_rnn_computation_func, + cell, + tf_inputs, + tf_sequence_length, + is_bidirectional, + initial_state=None, + time_major=None, + scope=None): + if is_bidirectional: + tf_result = create_rnn_computation_func( + cell_fw=cell, + cell_bw=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) + else: + tf_result = create_rnn_computation_func( + cell=cell, + inputs=tf_inputs, + sequence_length=tf_sequence_length, + initial_state=initial_state, + dtype=dtypes.float32, + time_major=time_major, + scope=scope) grad = gradients_impl.gradients(tf_result, variables.trainable_variables()) return {'inference': tf_result, 'grad': grad} @@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): variable_cache[n] = v def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache, - is_dynamic): + is_dynamic, time_major=None, is_bidirectional=False): with ops.Graph().as_default() as graph: tf_inputs = array_ops.placeholder( dtypes.float32, shape=numpy_inputs.shape) tf_slen = array_ops.placeholder(dtypes.int32) feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen} cell = self._CreateCell(cell_name) - fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn - fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen) + if is_dynamic: + if is_bidirectional: + fn = rnn_lib.bidirectional_dynamic_rnn + else: + fn = rnn_lib.dynamic_rnn + else: + if is_bidirectional: + fn = functional_rnn.bidirectional_functional_rnn + else: + fn = functional_rnn.functional_rnn + + fetches = self._CreateRnnGraph( + fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major) with self.test_session(graph=graph) as sess: sess.run(variables.global_variables_initializer()) # Note that cell.trainable_variables it not always set. @@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase): self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + def testLstmWithTimeMajorInputs(self): + """Checks an LSTM against the reference implementation, with time_major.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major=True) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major) + _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstmWithTimeMajorInputs(self): + """Checks a bi-directional LSTM with time-major inputs.""" + time_major = True + np_inputs, np_slen = self._CreateInputs(time_major) + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, func_rnn = self._RunRnn( + *(args + [False]), time_major=time_major, is_bidirectional=True) + _, dyn_rnn = self._RunRnn( + *(args + [True]), time_major=time_major, is_bidirectional=True) + self.assertAllClose(dyn_rnn['inference'], func_rnn['inference']) + # TODO(b/112170761): comment out this line after the bug is fixed. + # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad']) + + def testBidirectionalLstm(self): + """Checks time-major and batch-major rnn produce consistent results.""" + time_major_inputs, np_slen = self._CreateInputs(True) + batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2]) + var_cache = {} + args = [np_slen, 'lstm', var_cache, False] + _, time_major_rnn = self._RunRnn( + *([time_major_inputs] + args), time_major=True, is_bidirectional=True) + _, batch_major_rnn = self._RunRnn( + *([batch_major_inputs]+ args), time_major=False, is_bidirectional=True) + # Convert the batch-major outputs to be time-major before the comparasion. + outputs, state = batch_major_rnn['inference'] + outputs = [np.transpose(x, [1, 0, 2]) for x in outputs] + batch_major_rnn['inference'] = [outputs, state] + self.assertAllClose(time_major_rnn['inference'], + batch_major_rnn['inference']) + self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad']) + + def testBidirectionalLstmWithSymmetricInputs(self): + """Checks a bi-directional LSTM with symmetric inputs. + + time-major and batch-major rnn produce the same result with symmetric + inputs. + """ + np_inputs, np_slen = self._CreateSymmetricInputs() + var_cache = {} + args = [np_inputs, np_slen, 'lstm', var_cache] + _, time_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=True, is_bidirectional=True) + _, batch_major_func_rnn = self._RunRnn( + *(args + [False]), time_major=False, is_bidirectional=True) + _, time_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=True, is_bidirectional=True) + _, batch_major_dyn_rnn = self._RunRnn( + *(args + [True]), time_major=False, is_bidirectional=True) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_func_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_func_rnn['grad']) + self.assertAllClose(time_major_dyn_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad']) + self.assertAllClose(time_major_func_rnn['inference'], + batch_major_dyn_rnn['inference']) + self.assertAllClose(time_major_func_rnn['grad'], + batch_major_dyn_rnn['grad']) + if __name__ == '__main__': test_lib.main() diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py index a085474c1b..96cc3e997f 100644 --- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py +++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py @@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None, inputs=inputs, cell_fn=func_cell.cell_step, use_tpu=use_tpu) - return _PostProcessOutput(extended_acc_state, extended_final_state, - func_cell, inputs_flat[0].shape[0], sequence_length) + tf_output, tf_state = _PostProcessOutput( + extended_acc_state, extended_final_state, func_cell, + inputs_flat[0].shape[0], sequence_length) + + if time_major: + tf_output = array_ops.transpose(tf_output, [1, 0, 2]) + return tf_output, tf_state def bidirectional_functional_rnn( |