diff options
author | 2017-11-24 22:18:53 -0800 | |
---|---|---|
committer | 2017-11-24 22:22:30 -0800 | |
commit | 93bce00552ac70cc2c9b72e5742f9de87d72985a (patch) | |
tree | 12334c3cda2971b21499a381d23165ac8c720be3 /tensorflow/python/feature_column | |
parent | 080e432f2bd5566946887ef383acf0b5d34d150a (diff) |
Accept `vocabulary_size=None` in categorical_column_with_vocabulary_file()
When `vocabulary_size` is None, it defaults to the number of lines in the given vocabulary file.
PiperOrigin-RevId: 176881510
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 27 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 22 |
2 files changed, 39 insertions, 10 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 452f84192c..0686480ca4 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -152,6 +152,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import checkpoint_utils from tensorflow.python.util import nest @@ -980,9 +981,12 @@ def categorical_column_with_hash_bucket(key, return _HashedCategoricalColumn(key, hash_bucket_size, dtype) -def categorical_column_with_vocabulary_file( - key, vocabulary_file, vocabulary_size, num_oov_buckets=0, - default_value=None, dtype=dtypes.string): +def categorical_column_with_vocabulary_file(key, + vocabulary_file, + vocabulary_size=None, + num_oov_buckets=0, + default_value=None, + dtype=dtypes.string): """A `_CategoricalColumn` with a vocabulary file. Use this when your inputs are in string or integer format, and you have a @@ -1041,7 +1045,7 @@ def categorical_column_with_vocabulary_file( vocabulary_file: The vocabulary file name. vocabulary_size: Number of the elements in the vocabulary. This must be no greater than length of `vocabulary_file`, if less than length, later - values are ignored. + values are ignored. If None, it is set to the length of `vocabulary_file`. num_oov_buckets: Non-negative integer, the number of out-of-vocabulary buckets. All out-of-vocabulary inputs will be assigned IDs in the range `[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of @@ -1056,7 +1060,7 @@ def categorical_column_with_vocabulary_file( A `_CategoricalColumn` with a vocabulary file. Raises: - ValueError: `vocabulary_file` is missing. 
+ ValueError: `vocabulary_file` is missing or cannot be opened. ValueError: `vocabulary_size` is missing or < 1. ValueError: `num_oov_buckets` is a negative integer. ValueError: `num_oov_buckets` and `default_value` are both specified. @@ -1064,8 +1068,19 @@ def categorical_column_with_vocabulary_file( """ if not vocabulary_file: raise ValueError('Missing vocabulary_file in {}.'.format(key)) + + if vocabulary_size is None: + if not gfile.Exists(vocabulary_file): + raise ValueError('vocabulary_file in {} does not exist.'.format(key)) + + with gfile.GFile(vocabulary_file) as f: + vocabulary_size = sum(1 for _ in f) + logging.info( + 'vocabulary_size = %d in %s is inferred from the number of elements ' + 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file) + # `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`. - if (vocabulary_size is None) or (vocabulary_size < 1): + if vocabulary_size < 1: raise ValueError('Invalid vocabulary_size in {}.'.format(key)) if num_oov_buckets: if default_value is not None: diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 6ac5ce8757..d974f14b8a 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -2258,10 +2258,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file=self._wire_vocabulary_file_name, - vocabulary_size=None) - with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): - fc.categorical_column_with_vocabulary_file( - key='aaa', vocabulary_file=self._wire_vocabulary_file_name, vocabulary_size=-1) with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'): fc.categorical_column_with_vocabulary_file( @@ -2372,6 +2368,24 @@ class 
VocabularyFileCategoricalColumnTest(test.TestCase): dense_shape=inputs.dense_shape), id_weight_pair.id_tensor.eval()) + def test_get_sparse_tensors_none_vocabulary_size(self): + column = fc.categorical_column_with_vocabulary_file( + key='aaa', vocabulary_file=self._wire_vocabulary_file_name) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('marlo', 'skywalker', 'omar'), + dense_shape=(2, 2)) + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value(self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array( + (2, -1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + def test_transform_feature(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', |