diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-03-20 07:57:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-20 09:05:00 -0700 |
commit | 957253b638391bcf08dc335904f207a746d32721 (patch) | |
tree | 1e5a75f1e513d6cab0c84e5fc75cf2dd3d095c55 | |
parent | 63d8889cb303f9e58123ac1a266f494af31b43f4 (diff) |
Pass `sequence_length` to `dynamic_rnn()`.
Change: 150630961
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py | 9 |
2 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index d44780b282..ba839c66a3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -204,6 +204,7 @@ def build_sequence_input(features, def construct_rnn(initial_state, sequence_input, + sequence_length, cell, num_label_columns, dtype=dtypes.float32, @@ -216,6 +217,8 @@ def construct_rnn(initial_state, default starting state for `self._cell` is used. sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]` that will be passed as input to the RNN. + sequence_length: A `Tensor` with shape `[batch_size]` indicating the length + of the sequence input for each example. cell: An initialized `RNNCell`. num_label_columns: The desired output dimension. dtype: dtype of `cell`. @@ -236,6 +239,7 @@ def construct_rnn(initial_state, rnn_outputs, final_state = rnn.dynamic_rnn( cell=cell, inputs=sequence_input, + sequence_length=sequence_length, initial_state=initial_state, dtype=dtype, parallel_iterations=parallel_iterations, @@ -526,6 +530,7 @@ def _get_dynamic_rnn_model_fn( rnn_activations, final_state = construct_rnn( initial_state, sequence_input, + sequence_length, cell, target_column.num_label_columns, dtype=dtype, diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index c0d5933d48..45262ea6bc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -142,7 +142,8 @@ class DynamicRnnEstimatorTest(test.TestCase): dense_shape=[3, 2, 2]), 'measurements': random_ops.random_uniform( - [3, 2, 2], seed=4711) + [3, 2, 2], seed=4711), + 'sequence_length': constant_op.constant([2, 2, 1], dtype=dtypes.int32) } def GetClassificationTargetsOrNone(self, mode): @@ -168,11 +169,13 @@ class DynamicRnnEstimatorTest(test.TestCase): def testConstructRNN(self): initial_state = None + columns_to_tensors = self.GetColumnsToTensors() + sequence_length = columns_to_tensors.pop('sequence_length') sequence_input = dynamic_rnn_estimator.build_sequence_input( - self.GetColumnsToTensors(), self.sequence_feature_columns, + columns_to_tensors, self.sequence_feature_columns, self.context_feature_columns) activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn( - initial_state, sequence_input, self.rnn_cell, + initial_state, sequence_input, sequence_length, self.rnn_cell, self.mock_target_column.num_label_columns) # Obtain values of activations and final state. |