aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-06 11:23:41 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-06 11:30:53 -0800
commitc6feeafaabb09bdcda3e34009506c5dae596c5d9 (patch)
treeaf3ef92ac8139a57998759a285a01649989ec033 /tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
parent131f13afafd59278d4441f61f5f6e231b48f077c (diff)
Sequence versions of remaining categorical columns
PiperOrigin-RevId: 188051821
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py148
1 files changed, 146 insertions, 2 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 5c1e76fc62..c077f03291 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
import numpy as np
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
@@ -230,13 +231,17 @@ class SequenceInputLayerTest(test.TestCase):
def _assert_sparse_tensor_value(test_case, expected, actual):
- test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
- test_case.assertAllEqual(expected.indices, actual.indices)
+ _assert_sparse_tensor_indices_shape(test_case, expected, actual)
test_case.assertEqual(
np.array(expected.values).dtype, np.array(actual.values).dtype)
test_case.assertAllEqual(expected.values, actual.values)
+
+def _assert_sparse_tensor_indices_shape(test_case, expected, actual):
+ test_case.assertEqual(np.int64, np.array(actual.indices).dtype)
+ test_case.assertAllEqual(expected.indices, actual.indices)
+
test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype)
test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
@@ -314,6 +319,145 @@ class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
+class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
+
+ def test_get_sparse_tensors(self):
+ column = sfc.sequence_categorical_column_with_hash_bucket(
+ 'aaa', hash_bucket_size=10)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+
+ expected_sparse_ids = sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ # Ignored to avoid hash dependence in test.
+ values=np.array((0, 0, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))
+
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with monitored_session.MonitoredSession() as sess:
+ _assert_sparse_tensor_indices_shape(
+ self,
+ expected_sparse_ids,
+ id_weight_pair.id_tensor.eval(session=sess))
+
+ def test_sequence_length(self):
+ column = sfc.sequence_categorical_column_with_hash_bucket(
+ 'aaa', hash_bucket_size=10)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ expected_sequence_length = [1, 2]
+
+ sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
+
+ def _write_vocab(self, vocab_strings, file_name):
+ vocab_file = os.path.join(self.get_temp_dir(), file_name)
+ with open(vocab_file, 'w') as f:
+ f.write('\n'.join(vocab_strings))
+ return vocab_file
+
+ def setUp(self):
+ super(SequenceCategoricalColumnWithVocabularyFileTest, self).setUp()
+
+ vocab_strings = ['omar', 'stringer', 'marlo']
+ self._wire_vocabulary_file_name = self._write_vocab(vocab_strings,
+ 'wire_vocabulary.txt')
+ self._wire_vocabulary_size = 3
+
+ def test_get_sparse_tensors(self):
+ column = sfc.sequence_categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ expected_sparse_ids = sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))
+
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with monitored_session.MonitoredSession() as sess:
+ _assert_sparse_tensor_value(
+ self,
+ expected_sparse_ids,
+ id_weight_pair.id_tensor.eval(session=sess))
+
+ def test_sequence_length(self):
+ column = sfc.sequence_categorical_column_with_vocabulary_file(
+ key='aaa',
+ vocabulary_file=self._wire_vocabulary_file_name,
+ vocabulary_size=self._wire_vocabulary_size)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ expected_sequence_length = [1, 2]
+
+ sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
+
+ def test_get_sparse_tensors(self):
+ column = sfc.sequence_categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ expected_sparse_ids = sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))
+
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with monitored_session.MonitoredSession() as sess:
+ _assert_sparse_tensor_value(
+ self,
+ expected_sparse_ids,
+ id_weight_pair.id_tensor.eval(session=sess))
+
+ def test_sequence_length(self):
+ column = sfc.sequence_categorical_column_with_vocabulary_list(
+ key='aaa',
+ vocabulary_list=('omar', 'stringer', 'marlo'))
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ expected_sequence_length = [1, 2]
+
+ sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs}))
+
+ with monitored_session.MonitoredSession() as sess:
+ self.assertAllEqual(
+ expected_sequence_length, sequence_length.eval(session=sess))
+
+
class SequenceEmbeddingColumnTest(test.TestCase):
def test_get_sequence_dense_tensor(self):