diff options
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/rnn.py')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/rnn.py | 54 |
1 files changed, 2 insertions, 52 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py index 98660bb731..c595f47395 100644 --- a/tensorflow/contrib/estimator/python/estimator/rnn.py +++ b/tensorflow/contrib/estimator/python/estimator/rnn.py @@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.layers import core as core_layers from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables @@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'): return rnn_cell_fn -def _concatenate_context_input(sequence_input, context_input): - """Replicates `context_input` across all timesteps of `sequence_input`. - - Expands dimension 1 of `context_input` then tiles it `sequence_length` times. - This value is appended to `sequence_input` on dimension 2 and the result is - returned. - - Args: - sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size, - padded_length, d0]`. - context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`. - - Returns: - A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, - d0 + d1]`. - - Raises: - ValueError: If `sequence_input` does not have rank 3 or `context_input` does - not have rank 2. - """ - seq_rank_check = check_ops.assert_rank( - sequence_input, - 3, - message='sequence_input must have rank 3', - data=[array_ops.shape(sequence_input)]) - seq_type_check = check_ops.assert_type( - sequence_input, - dtypes.float32, - message='sequence_input must have dtype float32; got {}.'.format( - sequence_input.dtype)) - ctx_rank_check = check_ops.assert_rank( - context_input, - 2, - message='context_input must have rank 2', - data=[array_ops.shape(context_input)]) - ctx_type_check = check_ops.assert_type( - context_input, - dtypes.float32, - message='context_input must have dtype float32; got {}.'.format( - context_input.dtype)) - with ops.control_dependencies( - [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]): - padded_length = array_ops.shape(sequence_input)[1] - tiled_context_input = array_ops.tile( - array_ops.expand_dims(context_input, 1), - array_ops.concat([[1], [padded_length], [1]], 0)) - return array_ops.concat([sequence_input, tiled_context_input], 2) - - def _select_last_activations(activations, sequence_lengths): """Selects the nth set of activations for each n in `sequence_length`. @@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns, context_input = feature_column_lib.input_layer( features=features, feature_columns=context_feature_columns) - sequence_input = _concatenate_context_input(sequence_input, - context_input) + sequence_input = seq_fc.concatenate_context_input( + context_input, sequence_input) cell = rnn_cell_fn(mode) # Ignore output state. |