author     Shanqing Cai <cais@google.com>  2018-03-12 19:33:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-03-12 19:37:39 -0700
commit     7144571f2fc59c8705e4e3d7b922fa0ebf44f3fa (patch)
tree       b14683f826541c183c1bb783265e13b565469fbb /tensorflow/contrib/feature_column
parent     2bda52d485c9715dcd17f49526cea7890e091cb8 (diff)
Merge changes from github.
PiperOrigin-RevId: 188817194
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r--  tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py       325
-rw-r--r--  tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py  471
2 files changed, 796 insertions, 0 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py
new file mode 100644
index 0000000000..4ed7268e7a
--- /dev/null
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py
@@ -0,0 +1,325 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental methods for tf.feature_column sequence input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import abc
+import collections
+
+
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variable_scope
+
+# TODO(b/73160931): Fix pydoc.
+# pylint: disable=g-doc-args,missing-docstring,protected-access
+# TODO(b/73827486): Support SequenceExample.
+
+
+def sequence_input_layer(
+ features,
+ feature_columns,
+ weight_collections=None,
+ trainable=True,
+ scope=None):
+ """"Builds input layer for sequence input.
+
+ All `feature_columns` must be sequence dense columns with the same
+ `sequence_length`. The output of this method can be fed into sequence
+ networks, such as an RNN.
+
+ The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`.
+ `T` is the maximum sequence length for this batch, which could differ from
+ batch to batch.
+
+ If multiple `feature_columns` are given, each with `num_elements` `Di`,
+ their outputs are concatenated, so the final `Tensor` has shape
+ `[batch_size, T, D0 + D1 + ... + Dn]`.
+
+ Example:
+
+ ```python
+ rating = sequence_numeric_column('rating')
+ watches = sequence_categorical_column_with_identity(
+ 'watches', num_buckets=1000)
+ watches_embedding = embedding_column(watches, dimension=10)
+ columns = [rating, watches_embedding]
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
+ input_layer, sequence_length = sequence_input_layer(features, columns)
+
+ rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+ outputs, state = tf.nn.dynamic_rnn(
+ rnn_cell, inputs=input_layer, sequence_length=sequence_length)
+ ```
+
+ Returns:
+ An `(input_layer, sequence_length)` tuple where:
+ - input_layer: A float `Tensor` of shape `[batch_size, T, D]`.
+ `T` is the maximum sequence length for this batch, which could differ
+ from batch to batch. `D` is the sum of `num_elements` for all
+ `feature_columns`.
+ - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence
+ length for each example.
+ Raises:
+ ValueError: If any of the `feature_columns` is the wrong type.
+ """
+ feature_columns = fc._clean_feature_columns(feature_columns)
+ for c in feature_columns:
+ if not isinstance(c, _SequenceDenseColumn):
+ raise ValueError(
+ 'All feature_columns must be of type _SequenceDenseColumn. '
+ 'Given (type {}): {}'.format(type(c), c))
+
+ with variable_scope.variable_scope(
+ scope, default_name='sequence_input_layer', values=features.values()):
+ builder = fc._LazyBuilder(features)
+ output_tensors = []
+ sequence_lengths = []
+ ordered_columns = []
+ for column in sorted(feature_columns, key=lambda x: x.name):
+ ordered_columns.append(column)
+ with variable_scope.variable_scope(
+ None, default_name=column._var_scope_name):
+ dense_tensor, sequence_length = column._get_sequence_dense_tensor(
+ builder,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ # Flattens the final dimension to produce a 3D Tensor.
+ num_elements = column._variable_shape.num_elements()
+ shape = array_ops.shape(dense_tensor)
+ output_tensors.append(
+ array_ops.reshape(
+ dense_tensor,
+ shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
+ sequence_lengths.append(sequence_length)
+ fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
+ # TODO(b/73160931): Verify sequence_length equality.
+ return array_ops.concat(output_tensors, -1), sequence_lengths[0]
+
+
+# TODO(b/73160931): Add remaining categorical columns.
+def sequence_categorical_column_with_identity(
+ key, num_buckets, default_value=None):
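+ """Returns a feature column for sequence categorical identity data.
+
+ A usage sketch (this column wraps `fc.categorical_column_with_identity`;
+ see the `sequence_input_layer` docstring for end-to-end use):
+
+ ```python
+ watches = sequence_categorical_column_with_identity(
+ 'watches', num_buckets=1000)
+ watches_embedding = _sequence_embedding_column(watches, dimension=10)
+ ```
+ """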
+ return _SequenceCategoricalColumn(
+ fc.categorical_column_with_identity(
+ key=key,
+ num_buckets=num_buckets,
+ default_value=default_value))
+
+
+# TODO(b/73160931): Merge with embedding_column
+def _sequence_embedding_column(
+ categorical_column, dimension, initializer=None, ckpt_to_load_from=None,
+ tensor_name_in_ckpt=None, max_norm=None, trainable=True):
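+ """Returns a sequence-aware wrapper around `fc.embedding_column`.
+
+ `categorical_column` must be a `_SequenceCategoricalColumn`. The resulting
+ column provides a `[batch_size, T, dimension]` embedding lookup together
+ with the per-example sequence length.
+ """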
+ if not isinstance(categorical_column, _SequenceCategoricalColumn):
+ raise ValueError(
+ 'categorical_column must be of type _SequenceCategoricalColumn. '
+ 'Given (type {}): {}'.format(
+ type(categorical_column), categorical_column))
+ return _SequenceEmbeddingColumn(
+ fc.embedding_column(
+ categorical_column,
+ dimension=dimension,
+ initializer=initializer,
+ ckpt_to_load_from=ckpt_to_load_from,
+ tensor_name_in_ckpt=tensor_name_in_ckpt,
+ max_norm=max_norm,
+ trainable=trainable))
+
+
+def sequence_numeric_column(
+ key,
+ shape=(1,),
+ default_value=0.,
+ dtype=dtypes.float32):
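+ """Returns a feature column for sequence numeric data.
+
+ The feature is parsed as a `VarLenFeature` of `dtype` and densified with
+ `default_value`. A usage sketch (see the `sequence_input_layer` docstring
+ for end-to-end use):
+
+ ```python
+ rating = sequence_numeric_column('rating')
+ input_layer, sequence_length = sequence_input_layer(
+ features, feature_columns=[rating])
+ ```
+ """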
+ # TODO(b/73160931): Add validations.
+ return _SequenceNumericColumn(
+ key,
+ shape=shape,
+ default_value=default_value,
+ dtype=dtype)
+
+
+class _SequenceDenseColumn(fc._FeatureColumn):
+ """Represents dense sequence data."""
+
+ __metaclass__ = abc.ABCMeta
+
+ TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name
+ 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length'])
+
+ @abc.abstractproperty
+ def _variable_shape(self):
+ """`TensorShape` without batch and sequence dimensions."""
+ pass
+
+ @abc.abstractmethod
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ """Returns a `TensorSequenceLengthPair`."""
+ pass
+
+
+def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
+ with ops.name_scope(None, 'sequence_length') as name_scope:
+ row_ids = sp_tensor.indices[:, 0]
+ column_ids = sp_tensor.indices[:, 1]
+ column_ids += array_ops.ones_like(column_ids)
+ seq_length = (
+ math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
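+ # For example, with indices [[0, 0], [1, 0], [1, 1]] and num_elements=1,
+ # row_ids is [0, 1, 1] and the incremented column_ids is [1, 1, 2], so
+ # segment_max yields sequence lengths [1, 2] for examples 0 and 1.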
+ # If the last n rows do not have ids, seq_length will have shape
+ # [batch_size - n]. Pad the remaining values with zeros.
+ n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
+ padding = array_ops.zeros(n_pad, dtype=seq_length.dtype)
+ return array_ops.concat([seq_length, padding], axis=0, name=name_scope)
+
+
+class _SequenceCategoricalColumn(
+ fc._CategoricalColumn,
+ collections.namedtuple(
+ '_SequenceCategoricalColumn', ['categorical_column'])):
+
+ @property
+ def name(self):
+ return self.categorical_column.name
+
+ @property
+ def _parse_example_spec(self):
+ return self.categorical_column._parse_example_spec
+
+ def _transform_feature(self, inputs):
+ return self.categorical_column._transform_feature(inputs)
+
+ @property
+ def _num_buckets(self):
+ return self.categorical_column._num_buckets
+
+ def _get_sparse_tensors(self, inputs, weight_collections=None,
+ trainable=None):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
+ id_tensor = sparse_tensors.id_tensor
+ weight_tensor = sparse_tensors.weight_tensor
+ # Expands final dimension, so that embeddings are not combined during
+ # embedding lookup.
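+ # For example, an id_tensor with dense_shape [2, 2] becomes one with
+ # dense_shape [2, 2, 1], its indices extended by a trailing 0.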
+ check_id_rank = check_ops.assert_equal(
+ array_ops.rank(id_tensor), 2,
+ data=[
+ 'Column {} expected ID tensor of rank 2. '.format(self.name),
+ 'id_tensor shape: ', array_ops.shape(id_tensor)])
+ with ops.control_dependencies([check_id_rank]):
+ id_tensor = sparse_ops.sparse_reshape(
+ id_tensor,
+ shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+ if weight_tensor is not None:
+ check_weight_rank = check_ops.assert_equal(
+ array_ops.rank(weight_tensor), 2,
+ data=[
+ 'Column {} expected weight tensor of rank 2.'.format(self.name),
+ 'weight_tensor shape:', array_ops.shape(weight_tensor)])
+ with ops.control_dependencies([check_weight_rank]):
+ weight_tensor = sparse_ops.sparse_reshape(
+ weight_tensor,
+ shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
+
+ def _sequence_length(self, inputs):
+ sparse_tensors = self.categorical_column._get_sparse_tensors(inputs)
+ return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor)
+
+
+class _SequenceEmbeddingColumn(
+ _SequenceDenseColumn,
+ collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])):
+
+ @property
+ def name(self):
+ return self.embedding_column.name
+
+ @property
+ def _parse_example_spec(self):
+ return self.embedding_column._parse_example_spec
+
+ def _transform_feature(self, inputs):
+ return self.embedding_column._transform_feature(inputs)
+
+ @property
+ def _variable_shape(self):
+ return self.embedding_column._variable_shape
+
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ dense_tensor = self.embedding_column._get_dense_tensor(
+ inputs=inputs,
+ weight_collections=weight_collections,
+ trainable=trainable)
+ sequence_length = self.embedding_column.categorical_column._sequence_length(
+ inputs)
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+
+class _SequenceNumericColumn(
+ _SequenceDenseColumn,
+ collections.namedtuple(
+ '_SequenceNumericColumn',
+ ['key', 'shape', 'default_value', 'dtype'])):
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_spec(self):
+ return {self.key: parsing_ops.VarLenFeature(self.dtype)}
+
+ def _transform_feature(self, inputs):
+ return inputs.get(self.key)
+
+ @property
+ def _variable_shape(self):
+ return tensor_shape.TensorShape(self.shape)
+
+ def _get_sequence_dense_tensor(
+ self, inputs, weight_collections=None, trainable=None):
+ # Do nothing with weight_collections and trainable since no variables are
+ # created in this function.
+ del weight_collections
+ del trainable
+ sp_tensor = inputs.get(self)
+ dense_tensor = sparse_ops.sparse_tensor_to_dense(
+ sp_tensor, default_value=self.default_value)
+ # Reshape into [batch_size, T, variable_shape].
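+ # For example, with shape=(2, 2) a [batch_size, 8] dense tensor becomes
+ # [batch_size, 2, 2, 2], i.e. T=2 steps of 2x2 values.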
+ dense_shape = array_ops.concat(
+ [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
+ axis=0)
+ dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
+ sequence_length = _sequence_length_from_sparse_tensor(
+ sp_tensor, num_elements=self._variable_shape.num_elements())
+ return _SequenceDenseColumn.TensorSequenceLengthPair(
+ dense_tensor=dense_tensor, sequence_length=sequence_length)
+
+# pylint: enable=g-doc-args,missing-docstring,protected-access
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py
new file mode 100644
index 0000000000..59674869a2
--- /dev/null
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py
@@ -0,0 +1,471 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for sequential_feature_column."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc
+from tensorflow.python.feature_column.feature_column import _LazyBuilder
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+from tensorflow.python.training import monitored_session
+
+
+class SequenceInputLayerTest(test.TestCase):
+
+ def test_embedding_column(self):
+ vocabulary_size = 3
+ sparse_input_a = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ sparse_input_b = sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [2, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2))
+
+ embedding_dimension_a = 2
+ embedding_values_a = (
+ (1., 2.), # id 0
+ (3., 4.), # id 1
+ (5., 6.) # id 2
+ )
+ embedding_dimension_b = 3
+ embedding_values_b = (
+ (11., 12., 13.), # id 0
+ (14., 15., 16.), # id 1
+ (17., 18., 19.) # id 2
+ )
+ def _get_initializer(embedding_dimension, embedding_values):
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+ return _initializer
+
+ expected_input_layer = [
+ # example 0, ids_a [2], ids_b [1]
+ [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [2, 0]
+ [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],
+ ]
+ expected_sequence_length = [1, 2]
+
+ categorical_column_a = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column_a = sfc._sequence_embedding_column(
+ categorical_column_a, dimension=embedding_dimension_a,
+ initializer=_get_initializer(embedding_dimension_a, embedding_values_a))
+ categorical_column_b = sfc.sequence_categorical_column_with_identity(
+ key='bbb', num_buckets=vocabulary_size)
+ embedding_column_b = sfc._sequence_embedding_column(
+ categorical_column_b, dimension=embedding_dimension_b,
+ initializer=_get_initializer(embedding_dimension_b, embedding_values_b))
+
+ input_layer, sequence_length = sfc.sequence_input_layer(
+ features={
+ 'aaa': sparse_input_a,
+ 'bbb': sparse_input_b,
+ },
+ # Test that columns are reordered alphabetically.
+ feature_columns=[embedding_column_b, embedding_column_a])
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('sequence_input_layer/aaa_embedding/embedding_weights:0',
+ 'sequence_input_layer/bbb_embedding/embedding_weights:0'),
+ tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess))
+ self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess))
+ self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_numeric_column(self):
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+ expected_input_layer = [
+ [[0.], [1.]],
+ [[10.], [0.]],
+ ]
+ expected_sequence_length = [2, 1]
+ numeric_column = sfc.sequence_numeric_column('aaa')
+
+ input_layer, sequence_length = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[numeric_column])
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_numeric_column_multi_dim(self):
+ """Tests sequence_input_layer for multi-dimensional numeric_column."""
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
+ (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8))
+ # The output of numeric_column._get_dense_tensor should be flattened.
+ expected_input_layer = [
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+ [[10., 11., 12., 13.], [0., 0., 0., 0.]],
+ ]
+ expected_sequence_length = [2, 1]
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+ input_layer, sequence_length = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[numeric_column])
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess))
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+def _assert_sparse_tensor_value(test_case, expected, actual):
+ test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+ test_case.assertAllEqual(expected.indices, actual.indices)
+
+ test_case.assertEqual(
+ np.array(expected.values).dtype, np.array(actual.values).dtype)
+ test_case.assertAllEqual(expected.values, actual.values)
+
+ test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
+ test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
+
+
+class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
+
+ def test_get_sparse_tensors(self):
+ column = sfc.sequence_categorical_column_with_identity(
+ 'aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2))
+ expected_sparse_ids = sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((1, 2, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))
+
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with monitored_session.MonitoredSession() as sess:
+ _assert_sparse_tensor_value(
+ self,
+ expected_sparse_ids,
+ id_weight_pair.id_tensor.eval(session=sess))
+
+ def test_get_sparse_tensors_inputs3d(self):
+ """Tests _get_sparse_tensors when the input is already 3D Tensor."""
+ column = sfc.sequence_categorical_column_with_identity(
+ 'aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2, 1))
+
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r'Column aaa expected ID tensor of rank 2\.\s*'
+ r'id_tensor shape:\s*\[2 2 1\]'):
+ id_weight_pair = column._get_sparse_tensors(
+ _LazyBuilder({'aaa': inputs}))
+ with monitored_session.MonitoredSession() as sess:
+ id_weight_pair.id_tensor.eval(session=sess)
+
+ def test_sequence_length(self):
+ column = sfc.sequence_categorical_column_with_identity(
+ 'aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2))
+ expected_sequence_length = [1, 2]
+
+ sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_sequence_length_with_zeros(self):
+ column = sfc.sequence_categorical_column_with_identity(
+ 'aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((1, 0), (3, 0), (3, 1)),
+ values=(1, 2, 0),
+ dense_shape=(5, 2))
+ expected_sequence_length = [0, 1, 0, 2, 0]
+
+ sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceEmbeddingColumnTest(test.TestCase):
+
+ def test_get_sequence_dense_tensor(self):
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2))
+
+ embedding_dimension = 2
+ embedding_values = (
+ (1., 2.), # id 0
+ (3., 5.), # id 1
+ (7., 11.) # id 2
+ )
+ def _initializer(shape, dtype, partition_info):
+ self.assertAllEqual((vocabulary_size, embedding_dimension), shape)
+ self.assertEqual(dtypes.float32, dtype)
+ self.assertIsNone(partition_info)
+ return embedding_values
+
+ expected_lookups = [
+ # example 0, ids [2]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 2.], [3., 5.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [1]
+ [[3., 5.], [0., 0.]],
+ ]
+
+ categorical_column = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = sfc._sequence_embedding_column(
+ categorical_column, dimension=embedding_dimension,
+ initializer=_initializer)
+
+ embedding_lookup, _ = embedding_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+ self.assertItemsEqual(
+ ('embedding_weights:0',), tuple([v.name for v in global_vars]))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
+ self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess))
+
+ def test_sequence_length(self):
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2))
+ expected_sequence_length = [1, 2]
+
+ categorical_column = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = sfc._sequence_embedding_column(
+ categorical_column, dimension=2)
+
+ _, sequence_length = embedding_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_sequence_length_with_empty_rows(self):
+ """Tests _sequence_length when some examples do not have ids."""
+ vocabulary_size = 3
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, ids []
+ # example 1, ids [2]
+ # example 2, ids [0, 1]
+ # example 3, ids []
+ # example 4, ids [1]
+ # example 5, ids []
+ indices=((1, 0), (2, 0), (2, 1), (4, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(6, 2))
+ expected_sequence_length = [0, 1, 2, 0, 1, 0]
+
+ categorical_column = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=vocabulary_size)
+ embedding_column = sfc._sequence_embedding_column(
+ categorical_column, dimension=2)
+
+ _, sequence_length = embedding_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceNumericColumnTest(test.TestCase):
+
+ def test_get_sequence_dense_tensor(self):
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+ expected_dense_tensor = [
+ [[0.], [1.]],
+ [[10.], [0.]],
+ ]
+ numeric_column = sfc.sequence_numeric_column('aaa')
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
+ def test_get_sequence_dense_tensor_with_shape(self):
+ """Tests get_sequence_dense_tensor with shape !=(1,)."""
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2.], [3., 4., 5.]]
+ # example 1, [[10., 11., 12.]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
+ (1, 0), (1, 1), (1, 2)),
+ values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
+ dense_shape=(2, 6))
+ expected_dense_tensor = [
+ [[0., 1., 2.], [3., 4., 5.]],
+ [[10., 11., 12.], [0., 0., 0.]],
+ ]
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
+ def test_get_dense_tensor_multi_dim(self):
+ """Tests get_sequence_dense_tensor for multi-dim numeric_column."""
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
+ (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8))
+ expected_dense_tensor = [
+ [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
+ [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]],
+ ]
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+ dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_dense_tensor, dense_tensor.eval(session=sess))
+
+ def test_sequence_length(self):
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2.], [3., 4., 5.]]
+ # example 1, [[10., 11., 12.]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
+ (1, 0), (1, 1), (1, 2)),
+ values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
+ dense_shape=(2, 6))
+ expected_sequence_length = [2, 1]
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+
+ _, sequence_length = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_sequence_length_with_shape(self):
+ """Tests _sequence_length with shape !=(1,)."""
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values [[0.], [1]]
+ # example 1, [[10.]]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2))
+ expected_sequence_length = [2, 1]
+ numeric_column = sfc.sequence_numeric_column('aaa')
+
+ _, sequence_length = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+ def test_sequence_length_with_empty_rows(self):
+ """Tests _sequence_length when some examples do not have ids."""
+ sparse_input = sparse_tensor.SparseTensorValue(
+ # example 0, values []
+ # example 1, values [[0.], [1.]]
+ # example 2, [[2.]]
+ # example 3, values []
+ # example 4, [[3.]]
+ # example 5, values []
+ indices=((1, 0), (1, 1), (2, 0), (4, 0)),
+ values=(0., 1., 2., 3.),
+ dense_shape=(6, 2))
+ expected_sequence_length = [0, 2, 1, 0, 1, 0]
+ numeric_column = sfc.sequence_numeric_column('aaa')
+
+ _, sequence_length = numeric_column._get_sequence_dense_tensor(
+ _LazyBuilder({'aaa': sparse_input}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+if __name__ == '__main__':
+ test.main()