aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-20 07:57:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-20 09:05:00 -0700
commit957253b638391bcf08dc335904f207a746d32721 (patch)
tree1e5a75f1e513d6cab0c84e5fc75cf2dd3d095c55
parent63d8889cb303f9e58123ac1a266f494af31b43f4 (diff)
Pass `sequence_length` to `dynamic_rnn()`.
Change: 150630961
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py5
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py9
2 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
index d44780b282..ba839c66a3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py
@@ -204,6 +204,7 @@ def build_sequence_input(features,
def construct_rnn(initial_state,
sequence_input,
+ sequence_length,
cell,
num_label_columns,
dtype=dtypes.float32,
@@ -216,6 +217,8 @@ def construct_rnn(initial_state,
default starting state for `self._cell` is used.
sequence_input: A `Tensor` with shape `[batch_size, padded_length, d]`
that will be passed as input to the RNN.
+ sequence_length: A `Tensor` with shape `[batch_size]` indicating the length
+ of the sequence input for each example.
cell: An initialized `RNNCell`.
num_label_columns: The desired output dimension.
dtype: dtype of `cell`.
@@ -236,6 +239,7 @@ def construct_rnn(initial_state,
rnn_outputs, final_state = rnn.dynamic_rnn(
cell=cell,
inputs=sequence_input,
+ sequence_length=sequence_length,
initial_state=initial_state,
dtype=dtype,
parallel_iterations=parallel_iterations,
@@ -526,6 +530,7 @@ def _get_dynamic_rnn_model_fn(
rnn_activations, final_state = construct_rnn(
initial_state,
sequence_input,
+ sequence_length,
cell,
target_column.num_label_columns,
dtype=dtype,
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
index c0d5933d48..45262ea6bc 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py
@@ -142,7 +142,8 @@ class DynamicRnnEstimatorTest(test.TestCase):
dense_shape=[3, 2, 2]),
'measurements':
random_ops.random_uniform(
- [3, 2, 2], seed=4711)
+ [3, 2, 2], seed=4711),
+ 'sequence_length': constant_op.constant([2, 2, 1], dtype=dtypes.int32)
}
def GetClassificationTargetsOrNone(self, mode):
@@ -168,11 +169,13 @@ class DynamicRnnEstimatorTest(test.TestCase):
def testConstructRNN(self):
initial_state = None
+ columns_to_tensors = self.GetColumnsToTensors()
+ sequence_length = columns_to_tensors.pop('sequence_length')
sequence_input = dynamic_rnn_estimator.build_sequence_input(
- self.GetColumnsToTensors(), self.sequence_feature_columns,
+ columns_to_tensors, self.sequence_feature_columns,
self.context_feature_columns)
activations_t, final_state_t = dynamic_rnn_estimator.construct_rnn(
- initial_state, sequence_input, self.rnn_cell,
+ initial_state, sequence_input, sequence_length, self.rnn_cell,
self.mock_target_column.num_label_columns)
# Obtain values of activations and final state.