aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/estimator/python/estimator/rnn.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/estimator/python/estimator/rnn.py')
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py54
1 files changed, 2 insertions, 52 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 98660bb731..c595f47395 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
@@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'):
return rnn_cell_fn
-def _concatenate_context_input(sequence_input, context_input):
- """Replicates `context_input` across all timesteps of `sequence_input`.
-
- Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
- This value is appended to `sequence_input` on dimension 2 and the result is
- returned.
-
- Args:
- sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
- padded_length, d0]`.
- context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
-
- Returns:
- A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
- d0 + d1]`.
-
- Raises:
- ValueError: If `sequence_input` does not have rank 3 or `context_input` does
- not have rank 2.
- """
- seq_rank_check = check_ops.assert_rank(
- sequence_input,
- 3,
- message='sequence_input must have rank 3',
- data=[array_ops.shape(sequence_input)])
- seq_type_check = check_ops.assert_type(
- sequence_input,
- dtypes.float32,
- message='sequence_input must have dtype float32; got {}.'.format(
- sequence_input.dtype))
- ctx_rank_check = check_ops.assert_rank(
- context_input,
- 2,
- message='context_input must have rank 2',
- data=[array_ops.shape(context_input)])
- ctx_type_check = check_ops.assert_type(
- context_input,
- dtypes.float32,
- message='context_input must have dtype float32; got {}.'.format(
- context_input.dtype))
- with ops.control_dependencies(
- [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
- padded_length = array_ops.shape(sequence_input)[1]
- tiled_context_input = array_ops.tile(
- array_ops.expand_dims(context_input, 1),
- array_ops.concat([[1], [padded_length], [1]], 0))
- return array_ops.concat([sequence_input, tiled_context_input], 2)
-
-
def _select_last_activations(activations, sequence_lengths):
"""Selects the nth set of activations for each n in `sequence_length`.
@@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns,
context_input = feature_column_lib.input_layer(
features=features,
feature_columns=context_feature_columns)
- sequence_input = _concatenate_context_input(sequence_input,
- context_input)
+ sequence_input = seq_fc.concatenate_context_input(
+ context_input, sequence_input)
cell = rnn_cell_fn(mode)
# Ignore output state.