author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-08-07 14:31:49 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-08-07 14:39:52 -0700
commit    9b84c91b68915a26aa9d732988cbf13a7626c2dd (patch)
tree      371bf3f94eb07f38a988649d5e72f405d2aa832c /tensorflow/contrib/recurrent
parent    cf233f7281fbb841f5ea2df548e6e857308f222b (diff)
Fix the output shape of functional_rnn for time-major inputs.
PiperOrigin-RevId: 207780606
Diffstat (limited to 'tensorflow/contrib/recurrent')
-rw-r--r--  tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py  158
-rw-r--r--  tensorflow/contrib/recurrent/python/ops/functional_rnn.py                  9
2 files changed, 149 insertions, 18 deletions
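
For context, a minimal usage sketch of the behavior this change is meant to guarantee: with time_major=True, functional_rnn should return outputs shaped [time, batch, depth], matching dynamic_rnn. The sizes and the LSTMCell choice below are illustrative assumptions, not part of the patch (TF 1.x graph-mode API assumed; the import path matches the module touched by this change).

    import tensorflow as tf
    from tensorflow.contrib.recurrent.python.ops import functional_rnn

    # Illustrative sizes; the test below uses its own _TOTAL_TIME/_BATCH_SIZE/_INPUT_SIZE.
    total_time, batch_size, input_size, num_units = 12, 8, 4, 7

    # Time-major inputs: [time, batch, depth].
    inputs = tf.placeholder(tf.float32, [total_time, batch_size, input_size])
    seq_len = tf.placeholder(tf.int32, [batch_size])
    cell = tf.nn.rnn_cell.LSTMCell(num_units)

    outputs, state = functional_rnn.functional_rnn(
        cell=cell,
        inputs=inputs,
        sequence_length=seq_len,
        dtype=tf.float32,
        time_major=True)

    # With this fix, outputs keeps the time-major layout requested by the
    # caller: [total_time, batch_size, num_units], i.e. (12, 8, 7) here.
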
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
index 0f19ac7dbe..f23194a6f2 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/functional_rnn_test.py
@@ -61,10 +61,17 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
func, args = self._CELLDEFS[celldef_name]
return func(*args)
- def _CreateInputs(self):
- inputs = np.random.random([FunctionalRnnTest._BATCH_SIZE,
- FunctionalRnnTest._TOTAL_TIME,
- FunctionalRnnTest._INPUT_SIZE])
+ def _CreateInputs(self, time_major=False):
+ if time_major:
+ inputs = np.random.random([
+ FunctionalRnnTest._TOTAL_TIME, FunctionalRnnTest._BATCH_SIZE,
+ FunctionalRnnTest._INPUT_SIZE
+ ])
+ else:
+ inputs = np.random.random([
+ FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._TOTAL_TIME,
+ FunctionalRnnTest._INPUT_SIZE
+ ])
# Always leave one time slot empty, to check max_length behavior.
sequence_length = np.random.randint(
0, high=FunctionalRnnTest._TOTAL_TIME - 1,
@@ -72,15 +79,51 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
dtype=np.int)
return (inputs, sequence_length)
- def _CreateRnnGraph(self, create_rnn_computation_func, cell, tf_inputs,
- tf_sequence_length, initial_state=None,
- time_major=None, scope=None):
- tf_result = create_rnn_computation_func(cell=cell, inputs=tf_inputs,
- sequence_length=tf_sequence_length,
- initial_state=initial_state,
- dtype=dtypes.float32,
- time_major=time_major,
- scope=scope)
+ def _CreateSymmetricInputs(self):
+ # total time = batch size
+ inputs = np.zeros(
+ (FunctionalRnnTest._BATCH_SIZE, FunctionalRnnTest._BATCH_SIZE,
+ FunctionalRnnTest._INPUT_SIZE))
+ for i in range(FunctionalRnnTest._BATCH_SIZE):
+ for j in range(i, FunctionalRnnTest._BATCH_SIZE):
+ inputs[i][j] = np.random.random([FunctionalRnnTest._INPUT_SIZE])
+ inputs[j][i] = inputs[i][j]
+
+ # Always leave one time slot empty, to check max_length behavior.
+ sequence_length = np.random.randint(
+ 0,
+ high=FunctionalRnnTest._BATCH_SIZE - 1,
+ size=FunctionalRnnTest._BATCH_SIZE,
+ dtype=np.int)
+ return (inputs, sequence_length)
+
+ def _CreateRnnGraph(self,
+ create_rnn_computation_func,
+ cell,
+ tf_inputs,
+ tf_sequence_length,
+ is_bidirectional,
+ initial_state=None,
+ time_major=None,
+ scope=None):
+ if is_bidirectional:
+ tf_result = create_rnn_computation_func(
+ cell_fw=cell,
+ cell_bw=cell,
+ inputs=tf_inputs,
+ sequence_length=tf_sequence_length,
+ dtype=dtypes.float32,
+ time_major=time_major,
+ scope=scope)
+ else:
+ tf_result = create_rnn_computation_func(
+ cell=cell,
+ inputs=tf_inputs,
+ sequence_length=tf_sequence_length,
+ initial_state=initial_state,
+ dtype=dtypes.float32,
+ time_major=time_major,
+ scope=scope)
grad = gradients_impl.gradients(tf_result, variables.trainable_variables())
return {'inference': tf_result, 'grad': grad}
@@ -102,15 +145,26 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
variable_cache[n] = v
def _RunRnn(self, numpy_inputs, numpy_slen, cell_name, variable_cache,
- is_dynamic):
+ is_dynamic, time_major=None, is_bidirectional=False):
with ops.Graph().as_default() as graph:
tf_inputs = array_ops.placeholder(
dtypes.float32, shape=numpy_inputs.shape)
tf_slen = array_ops.placeholder(dtypes.int32)
feeds = {tf_inputs: numpy_inputs, tf_slen: numpy_slen}
cell = self._CreateCell(cell_name)
- fn = rnn_lib.dynamic_rnn if is_dynamic else functional_rnn.functional_rnn
- fetches = self._CreateRnnGraph(fn, cell, tf_inputs, tf_slen)
+ if is_dynamic:
+ if is_bidirectional:
+ fn = rnn_lib.bidirectional_dynamic_rnn
+ else:
+ fn = rnn_lib.dynamic_rnn
+ else:
+ if is_bidirectional:
+ fn = functional_rnn.bidirectional_functional_rnn
+ else:
+ fn = functional_rnn.functional_rnn
+
+ fetches = self._CreateRnnGraph(
+ fn, cell, tf_inputs, tf_slen, is_bidirectional, time_major=time_major)
with self.test_session(graph=graph) as sess:
sess.run(variables.global_variables_initializer())
# Note that cell.trainable_variables is not always set.
@@ -158,6 +212,78 @@ class FunctionalRnnTest(test_util.TensorFlowTestCase):
self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+ def testLstmWithTimeMajorInputs(self):
+ """Checks an LSTM against the reference implementation, with time_major."""
+ time_major = True
+ np_inputs, np_slen = self._CreateInputs(time_major=True)
+ var_cache = {}
+ args = [np_inputs, np_slen, 'lstm', var_cache]
+ _, func_rnn = self._RunRnn(*(args + [False]), time_major=time_major)
+ _, dyn_rnn = self._RunRnn(*(args + [True]), time_major=time_major)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testBidirectionalLstmWithTimeMajorInputs(self):
+ """Checks a bi-directional LSTM with time-major inputs."""
+ time_major = True
+ np_inputs, np_slen = self._CreateInputs(time_major)
+ var_cache = {}
+ args = [np_inputs, np_slen, 'lstm', var_cache]
+ _, func_rnn = self._RunRnn(
+ *(args + [False]), time_major=time_major, is_bidirectional=True)
+ _, dyn_rnn = self._RunRnn(
+ *(args + [True]), time_major=time_major, is_bidirectional=True)
+ self.assertAllClose(dyn_rnn['inference'], func_rnn['inference'])
+ # TODO(b/112170761): uncomment this check once the bug is fixed.
+ # self.assertAllClose(dyn_rnn['grad'], func_rnn['grad'])
+
+ def testBidirectionalLstm(self):
+ """Checks time-major and batch-major rnn produce consistent results."""
+ time_major_inputs, np_slen = self._CreateInputs(True)
+ batch_major_inputs = np.transpose(time_major_inputs, [1, 0, 2])
+ var_cache = {}
+ args = [np_slen, 'lstm', var_cache, False]
+ _, time_major_rnn = self._RunRnn(
+ *([time_major_inputs] + args), time_major=True, is_bidirectional=True)
+ _, batch_major_rnn = self._RunRnn(
+ *([batch_major_inputs] + args), time_major=False, is_bidirectional=True)
+ # Convert the batch-major outputs to time-major before the comparison.
+ outputs, state = batch_major_rnn['inference']
+ outputs = [np.transpose(x, [1, 0, 2]) for x in outputs]
+ batch_major_rnn['inference'] = [outputs, state]
+ self.assertAllClose(time_major_rnn['inference'],
+ batch_major_rnn['inference'])
+ self.assertAllClose(time_major_rnn['grad'], batch_major_rnn['grad'])
+
+ def testBidirectionalLstmWithSymmetricInputs(self):
+ """Checks a bi-directional LSTM with symmetric inputs.
+
+ Time-major and batch-major RNNs should produce the same result with
+ symmetric inputs.
+ """
+ np_inputs, np_slen = self._CreateSymmetricInputs()
+ var_cache = {}
+ args = [np_inputs, np_slen, 'lstm', var_cache]
+ _, time_major_func_rnn = self._RunRnn(
+ *(args + [False]), time_major=True, is_bidirectional=True)
+ _, batch_major_func_rnn = self._RunRnn(
+ *(args + [False]), time_major=False, is_bidirectional=True)
+ _, time_major_dyn_rnn = self._RunRnn(
+ *(args + [True]), time_major=True, is_bidirectional=True)
+ _, batch_major_dyn_rnn = self._RunRnn(
+ *(args + [True]), time_major=False, is_bidirectional=True)
+ self.assertAllClose(time_major_func_rnn['inference'],
+ batch_major_func_rnn['inference'])
+ self.assertAllClose(time_major_func_rnn['grad'],
+ batch_major_func_rnn['grad'])
+ self.assertAllClose(time_major_dyn_rnn['inference'],
+ batch_major_dyn_rnn['inference'])
+ self.assertAllClose(time_major_dyn_rnn['grad'], batch_major_dyn_rnn['grad'])
+ self.assertAllClose(time_major_func_rnn['inference'],
+ batch_major_dyn_rnn['inference'])
+ self.assertAllClose(time_major_func_rnn['grad'],
+ batch_major_dyn_rnn['grad'])
+
if __name__ == '__main__':
test_lib.main()
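
A side note on the symmetric-inputs trick used in testBidirectionalLstmWithSymmetricInputs above: because total time equals batch size and inputs[i][j] == inputs[j][i], swapping the first two axes leaves the tensor unchanged, so the time-major and batch-major runs see exactly the same data. A standalone numpy sketch of that property (not part of the patch):

    import numpy as np

    batch = time = 4  # the test uses _BATCH_SIZE for both axes
    depth = 3

    # Inputs symmetric in the first two (time/batch) axes.
    inputs = np.zeros((batch, time, depth))
    for i in range(batch):
      for j in range(i, time):
        inputs[i][j] = np.random.random(depth)
        inputs[j][i] = inputs[i][j]

    # Transposing the time and batch axes leaves symmetric inputs unchanged.
    assert np.allclose(inputs, np.transpose(inputs, [1, 0, 2]))
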
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index a085474c1b..96cc3e997f 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -284,8 +284,13 @@ def functional_rnn(cell, inputs, sequence_length=None,
inputs=inputs,
cell_fn=func_cell.cell_step,
use_tpu=use_tpu)
- return _PostProcessOutput(extended_acc_state, extended_final_state,
- func_cell, inputs_flat[0].shape[0], sequence_length)
+ tf_output, tf_state = _PostProcessOutput(
+ extended_acc_state, extended_final_state, func_cell,
+ inputs_flat[0].shape[0], sequence_length)
+
+ if time_major:
+ tf_output = array_ops.transpose(tf_output, [1, 0, 2])
+ return tf_output, tf_state
def bidirectional_functional_rnn(
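
The fix itself is the added transpose above: judging from the change, _PostProcessOutput yields batch-major output ([batch, time, depth]), so when the caller asked for time_major=True the first two axes have to be swapped back before returning. A numpy sketch of that shape conversion (sizes are illustrative):

    import numpy as np

    batch_size, total_time, output_size = 8, 12, 7  # illustrative sizes

    # Batch-major output, as produced internally: [batch, time, depth].
    batch_major_output = np.random.random((batch_size, total_time, output_size))

    # The added transpose([1, 0, 2]) restores the time-major layout
    # requested by the caller.
    time_major_output = np.transpose(batch_major_output, [1, 0, 2])
    assert time_major_output.shape == (total_time, batch_size, output_size)
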