aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py72
1 files changed, 65 insertions, 7 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index 05bcdac2ca..dd6da35ed0 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
# pylint: disable=protected-access
-# TODO(b/73827486): Support SequenceExample.
def sequence_input_layer(
@@ -110,6 +109,7 @@ def sequence_input_layer(
output_tensors = []
sequence_lengths = []
ordered_columns = []
+
for column in sorted(feature_columns, key=lambda x: x.name):
ordered_columns.append(column)
with variable_scope.variable_scope(
@@ -121,17 +121,67 @@ def sequence_input_layer(
# Flattens the final dimension to produce a 3D Tensor.
num_elements = column._variable_shape.num_elements()
shape = array_ops.shape(dense_tensor)
+ target_shape = [shape[0], shape[1], num_elements]
output_tensors.append(
- array_ops.reshape(
- dense_tensor,
- shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
+ array_ops.reshape(dense_tensor, shape=target_shape))
sequence_lengths.append(sequence_length)
+
fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
sequence_length = _assert_all_equal_and_return(sequence_lengths)
+
return array_ops.concat(output_tensors, -1), sequence_length
+def concatenate_context_input(context_input, sequence_input):
+ """Replicates `context_input` across all timesteps of `sequence_input`.
+
+ Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
+ This value is appended to `sequence_input` on dimension 2 and the result is
+ returned.
+
+ Args:
+ context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
+ sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
+ padded_length, d0]`.
+
+ Returns:
+ A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
+ d0 + d1]`.
+
+ Raises:
+ ValueError: If `sequence_input` does not have rank 3 or `context_input` does
+ not have rank 2.
+ """
+ seq_rank_check = check_ops.assert_rank(
+ sequence_input,
+ 3,
+ message='sequence_input must have rank 3',
+ data=[array_ops.shape(sequence_input)])
+ seq_type_check = check_ops.assert_type(
+ sequence_input,
+ dtypes.float32,
+ message='sequence_input must have dtype float32; got {}.'.format(
+ sequence_input.dtype))
+ ctx_rank_check = check_ops.assert_rank(
+ context_input,
+ 2,
+ message='context_input must have rank 2',
+ data=[array_ops.shape(context_input)])
+ ctx_type_check = check_ops.assert_type(
+ context_input,
+ dtypes.float32,
+ message='context_input must have dtype float32; got {}.'.format(
+ context_input.dtype))
+ with ops.control_dependencies(
+ [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
+ padded_length = array_ops.shape(sequence_input)[1]
+ tiled_context_input = array_ops.tile(
+ array_ops.expand_dims(context_input, 1),
+ array_ops.concat([[1], [padded_length], [1]], 0))
+ return array_ops.concat([sequence_input, tiled_context_input], 2)
+
+
def sequence_categorical_column_with_identity(
key, num_buckets, default_value=None):
"""Returns a feature column that represents sequences of integers.
@@ -453,9 +503,17 @@ class _SequenceNumericColumn(
[array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
axis=0)
dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
- sequence_length = fc._sequence_length_from_sparse_tensor(
- sp_tensor, num_elements=self._variable_shape.num_elements())
+
+ # Get the number of timesteps per example
+ # For the 2D case, the raw values are grouped according to num_elements;
+ # for the 3D case, the grouping happens in the third dimension, and
+ # sequence length is not affected.
+ num_elements = (self._variable_shape.num_elements()
+ if sp_tensor.shape.ndims == 2 else 1)
+ seq_length = fc._sequence_length_from_sparse_tensor(
+ sp_tensor, num_elements=num_elements)
+
return fc._SequenceDenseColumn.TensorSequenceLengthPair(
- dense_tensor=dense_tensor, sequence_length=sequence_length)
+ dense_tensor=dense_tensor, sequence_length=seq_length)
# pylint: enable=protected-access