aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Makoto Uchida <muchida@google.com>2017-11-24 22:18:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-24 22:22:30 -0800
commit93bce00552ac70cc2c9b72e5742f9de87d72985a (patch)
tree12334c3cda2971b21499a381d23165ac8c720be3 /tensorflow/python/feature_column
parent080e432f2bd5566946887ef383acf0b5d34d150a (diff)
Accept None vocabulary_size to categorical_column_with_vocabulary_file()
Defaults to the length of the given vocabulary file. PiperOrigin-RevId: 176881510
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py27
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py22
2 files changed, 39 insertions, 10 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 452f84192c..0686480ca4 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -152,6 +152,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.util import nest
@@ -980,9 +981,12 @@ def categorical_column_with_hash_bucket(key,
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
-def categorical_column_with_vocabulary_file(
- key, vocabulary_file, vocabulary_size, num_oov_buckets=0,
- default_value=None, dtype=dtypes.string):
+def categorical_column_with_vocabulary_file(key,
+ vocabulary_file,
+ vocabulary_size=None,
+ num_oov_buckets=0,
+ default_value=None,
+ dtype=dtypes.string):
"""A `_CategoricalColumn` with a vocabulary file.
Use this when your inputs are in string or integer format, and you have a
@@ -1041,7 +1045,7 @@ def categorical_column_with_vocabulary_file(
vocabulary_file: The vocabulary file name.
vocabulary_size: Number of the elements in the vocabulary. This must be no
greater than length of `vocabulary_file`, if less than length, later
- values are ignored.
+ values are ignored. If None, it is set to the length of `vocabulary_file`.
num_oov_buckets: Non-negative integer, the number of out-of-vocabulary
buckets. All out-of-vocabulary inputs will be assigned IDs in the range
`[vocabulary_size, vocabulary_size+num_oov_buckets)` based on a hash of
@@ -1056,7 +1060,7 @@ def categorical_column_with_vocabulary_file(
A `_CategoricalColumn` with a vocabulary file.
Raises:
- ValueError: `vocabulary_file` is missing.
+ ValueError: `vocabulary_file` is missing or cannot be opened.
ValueError: `vocabulary_size` is missing or < 1.
ValueError: `num_oov_buckets` is a negative integer.
ValueError: `num_oov_buckets` and `default_value` are both specified.
@@ -1064,8 +1068,19 @@ def categorical_column_with_vocabulary_file(
"""
if not vocabulary_file:
raise ValueError('Missing vocabulary_file in {}.'.format(key))
+
+ if vocabulary_size is None:
+ if not gfile.Exists(vocabulary_file):
+ raise ValueError('vocabulary_file in {} does not exist.'.format(key))
+
+ with gfile.GFile(vocabulary_file) as f:
+ vocabulary_size = sum(1 for _ in f)
+ logging.info(
+ 'vocabulary_size = %d in %s is inferred from the number of elements '
+ 'in the vocabulary_file %s.', vocabulary_size, key, vocabulary_file)
+
# `vocabulary_size` isn't required for lookup, but it is for `_num_buckets`.
- if (vocabulary_size is None) or (vocabulary_size < 1):
+ if vocabulary_size < 1:
raise ValueError('Invalid vocabulary_size in {}.'.format(key))
if num_oov_buckets:
if default_value is not None:
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 6ac5ce8757..d974f14b8a 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -2258,10 +2258,6 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=None)
- with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
- fc.categorical_column_with_vocabulary_file(
- key='aaa', vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=-1)
with self.assertRaisesRegexp(ValueError, 'Invalid vocabulary_size'):
fc.categorical_column_with_vocabulary_file(
@@ -2372,6 +2368,24 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
dense_shape=inputs.dense_shape),
id_weight_pair.id_tensor.eval())
+ def test_get_sparse_tensors_none_vocabulary_size(self):
+ column = fc.categorical_column_with_vocabulary_file(
+ key='aaa', vocabulary_file=self._wire_vocabulary_file_name)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2))
+ id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array(
+ (2, -1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
def test_transform_feature(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa',