aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar Mustafa Ispir <ispir@google.com>2018-05-14 22:04:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-14 22:07:42 -0700
commit2869a86c56b163318cfb47126f3c7f56db0b642c (patch)
tree364cf310e8dd686bae720f1c413a3cd3d5f97811 /tensorflow/python/feature_column
parentfdf36165090f465cc2464de26c939237c45155a3 (diff)
Added type check to feature column keys. So that users will get meaningful error messages in situations like: #19219
PiperOrigin-RevId: 196616638
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py22
2 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index c16c3cda48..1d50892a88 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -1068,6 +1068,7 @@ def numeric_column(key,
raise TypeError(
'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+ _assert_key_is_string(key)
return _NumericColumn(
key,
shape=shape,
@@ -1166,6 +1167,13 @@ def _assert_string_or_int(dtype, prefix):
'{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+def _assert_key_is_string(key):
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ 'key must be a string. Got: type {}. Given key: {}.'.format(
+ type(key), key))
+
+
@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
@@ -1218,6 +1226,7 @@ def categorical_column_with_hash_bucket(key,
'hash_bucket_size: {}, key: {}'.format(
hash_bucket_size, key))
+ _assert_key_is_string(key)
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
@@ -1334,6 +1343,7 @@ def categorical_column_with_vocabulary_file(key,
raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
num_oov_buckets, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyFileCategoricalColumn(
key=key,
vocabulary_file=vocabulary_file,
@@ -1448,6 +1458,7 @@ def categorical_column_with_vocabulary_list(
'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
dtype, vocabulary_dtype, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyListCategoricalColumn(
key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
@@ -1518,6 +1529,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
raise ValueError(
'default_value {} not in range [0, {}), column_name {}'.format(
default_value, num_buckets, key))
+ _assert_key_is_string(key)
return _IdentityCategoricalColumn(
key=key, num_buckets=num_buckets, default_value=default_value)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b06540489f..03c47eea31 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -182,6 +182,10 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(dtypes.float32, a.dtype)
self.assertIsNone(a.normalizer_fn)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.numeric_column(key=('aaa',))
+
def test_shape_saved_as_tuple(self):
a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
self.assertEqual((1, 2), a.shape)
@@ -645,6 +649,10 @@ class HashedCategoricalColumnTest(test.TestCase):
self.assertEqual(10, a.hash_bucket_size)
self.assertEqual(dtypes.string, a.dtype)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_hash_bucket(('key',), 10)
+
def test_bucket_size_should_be_given(self):
with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
fc.categorical_column_with_hash_bucket('aaa', None)
@@ -3327,6 +3335,11 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_file(
+ key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
@@ -3752,6 +3765,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_list(
+ key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
def test_defaults_int(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=(12, 24, 36))
@@ -4143,6 +4161,10 @@ class IdentityCategoricalColumnTest(test.TestCase):
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
def test_deep_copy(self):
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
for column in (original, copy.deepcopy(original)):