diff options
author | Mustafa Ispir <ispir@google.com> | 2018-05-14 22:04:50 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-14 22:07:42 -0700 |
commit | 2869a86c56b163318cfb47126f3c7f56db0b642c (patch) | |
tree | 364cf310e8dd686bae720f1c413a3cd3d5f97811 /tensorflow/python/feature_column | |
parent | fdf36165090f465cc2464de26c939237c45155a3 (diff) |
Added a type check to feature column keys so that users get meaningful error messages in situations like #19219.
PiperOrigin-RevId: 196616638
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 12 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 22 |
2 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index c16c3cda48..1d50892a88 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -1068,6 +1068,7 @@ def numeric_column(key, raise TypeError( 'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn)) + _assert_key_is_string(key) return _NumericColumn( key, shape=shape, @@ -1166,6 +1167,13 @@ def _assert_string_or_int(dtype, prefix): '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype)) +def _assert_key_is_string(key): + if not isinstance(key, six.string_types): + raise ValueError( + 'key must be a string. Got: type {}. Given key: {}.'.format( + type(key), key)) + + @tf_export('feature_column.categorical_column_with_hash_bucket') def categorical_column_with_hash_bucket(key, hash_bucket_size, @@ -1218,6 +1226,7 @@ def categorical_column_with_hash_bucket(key, 'hash_bucket_size: {}, key: {}'.format( hash_bucket_size, key)) + _assert_key_is_string(key) _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) return _HashedCategoricalColumn(key, hash_bucket_size, dtype) @@ -1334,6 +1343,7 @@ def categorical_column_with_vocabulary_file(key, raise ValueError('Invalid num_oov_buckets {} in {}.'.format( num_oov_buckets, key)) _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + _assert_key_is_string(key) return _VocabularyFileCategoricalColumn( key=key, vocabulary_file=vocabulary_file, @@ -1448,6 +1458,7 @@ def categorical_column_with_vocabulary_list( 'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format( dtype, vocabulary_dtype, key)) _assert_string_or_int(dtype, prefix='column_name: {}'.format(key)) + _assert_key_is_string(key) return _VocabularyListCategoricalColumn( key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype, @@ -1518,6 +1529,7 @@ def categorical_column_with_identity(key, num_buckets, 
default_value=None): raise ValueError( 'default_value {} not in range [0, {}), column_name {}'.format( default_value, num_buckets, key)) + _assert_key_is_string(key) return _IdentityCategoricalColumn( key=key, num_buckets=num_buckets, default_value=default_value) diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index b06540489f..03c47eea31 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -182,6 +182,10 @@ class NumericColumnTest(test.TestCase): self.assertEqual(dtypes.float32, a.dtype) self.assertIsNone(a.normalizer_fn) + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.numeric_column(key=('aaa',)) + def test_shape_saved_as_tuple(self): a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]]) self.assertEqual((1, 2), a.shape) @@ -645,6 +649,10 @@ class HashedCategoricalColumnTest(test.TestCase): self.assertEqual(10, a.hash_bucket_size) self.assertEqual(dtypes.string, a.dtype) + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_hash_bucket(('key',), 10) + def test_bucket_size_should_be_given(self): with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'): fc.categorical_column_with_hash_bucket('aaa', None) @@ -3327,6 +3335,11 @@ class VocabularyFileCategoricalColumnTest(test.TestCase): 'aaa': parsing_ops.VarLenFeature(dtypes.string) }, column._parse_example_spec) + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_vocabulary_file( + key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3) + def test_all_constructor_args(self): column = fc.categorical_column_with_vocabulary_file( key='aaa', vocabulary_file='path_to_file', vocabulary_size=3, @@ 
-3752,6 +3765,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase): 'aaa': parsing_ops.VarLenFeature(dtypes.string) }, column._parse_example_spec) + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_vocabulary_list( + key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo')) + def test_defaults_int(self): column = fc.categorical_column_with_vocabulary_list( key='aaa', vocabulary_list=(12, 24, 36)) @@ -4143,6 +4161,10 @@ class IdentityCategoricalColumnTest(test.TestCase): 'aaa': parsing_ops.VarLenFeature(dtypes.int64) }, column._parse_example_spec) + def test_key_should_be_string(self): + with self.assertRaisesRegexp(ValueError, 'key must be a string.'): + fc.categorical_column_with_identity(key=('aaa',), num_buckets=3) + def test_deep_copy(self): original = fc.categorical_column_with_identity(key='aaa', num_buckets=3) for column in (original, copy.deepcopy(original)): |