aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_test.py')
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6ac5ce8757..d974f14b8a 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -2258,10 +2258,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=None)
- with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
- fc.categorical_column_with_vocabulary_file(
- key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=-1)
with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
fc.categorical_column_with_vocabulary_file(
@@ -2372,6 +2368,24 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_weight_pair.id_tensor.eval())
+ def test_get_sparse_tensors_none_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
def test_transform_feature(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',