aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-22 11:12:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-22 11:14:40 -0700
commit1a6752dddf387d280a6a13c2dc7e2bebf69dab2f (patch)
tree7c88d9d18b4baade07f96156dd6add0927275fc7 /tensorflow/contrib/feature_column
parent6fa811a94f3da0c49d69db9b15ea424f84a6431f (diff)
Adds remaining validations in sequence_numeric_column.
PiperOrigin-RevId: 190094883
Diffstat (limited to 'tensorflow/contrib/feature_column')
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py32
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py26
2 files changed, 57 insertions, 1 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index e60116966f..555beddeaa 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -166,6 +166,10 @@ def sequence_categorical_column_with_identity(
Returns:
A `_SequenceCategoricalColumn`.
+
+ Raises:
+ ValueError: if `num_buckets` is less than one.
+ ValueError: if `default_value` is not in range `[0, num_buckets)`.
"""
return fc._SequenceCategoricalColumn(
fc.categorical_column_with_identity(
@@ -205,6 +209,10 @@ def sequence_categorical_column_with_hash_bucket(
Returns:
A `_SequenceCategoricalColumn`.
+
+ Raises:
+ ValueError: `hash_bucket_size` is not greater than 1.
+ ValueError: `dtype` is neither string nor integer.
"""
return fc._SequenceCategoricalColumn(
fc.categorical_column_with_hash_bucket(
@@ -257,6 +265,13 @@ def sequence_categorical_column_with_vocabulary_file(
Returns:
A `_SequenceCategoricalColumn`.
+
+ Raises:
+ ValueError: `vocabulary_file` is missing or cannot be opened.
+ ValueError: `vocabulary_size` is missing or < 1.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: `dtype` is neither string nor integer.
"""
return fc._SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_file(
@@ -311,6 +326,12 @@ def sequence_categorical_column_with_vocabulary_list(
Returns:
A `_SequenceCategoricalColumn`.
+
+ Raises:
+ ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
+ ValueError: `num_oov_buckets` is a negative integer.
+ ValueError: `num_oov_buckets` and `default_value` are both specified.
+ ValueError: if `dtype` is not integer or string.
"""
return fc._SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_list(
@@ -352,8 +373,17 @@ def sequence_numeric_column(
Returns:
A `_SequenceNumericColumn`.
+
+ Raises:
+ TypeError: if any dimension in shape is not an int.
+ ValueError: if any dimension in shape is not a positive integer.
+ ValueError: if `dtype` is not convertible to `tf.float32`.
"""
- # TODO(b/73160931): Add validations.
+ shape = fc._check_shape(shape=shape, key=key)
+ if not (dtype.is_integer or dtype.is_floating):
+ raise ValueError('dtype must be convertible to float. '
+ 'dtype: {}, key: {}'.format(dtype, key))
+
return _SequenceNumericColumn(
key,
shape=shape,
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index b64f086376..88f5d53516 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -662,6 +662,32 @@ class SequenceIndicatorColumnTest(test.TestCase):
class SequenceNumericColumnTest(test.TestCase):
+ def test_defaults(self):
+ a = sfc.sequence_numeric_column('aaa')
+ self.assertEqual('aaa', a.key)
+ self.assertEqual('aaa', a.name)
+ self.assertEqual('aaa', a._var_scope_name)
+ self.assertEqual((1,), a.shape)
+ self.assertEqual(0., a.default_value)
+ self.assertEqual(dtypes.float32, a.dtype)
+
+ def test_shape_saved_as_tuple(self):
+ a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
+ self.assertEqual((1, 2), a.shape)
+
+ def test_shape_must_be_positive_integer(self):
+ with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
+ sfc.sequence_numeric_column('aaa', shape=[1.0])
+
+ with self.assertRaisesRegexp(
+ ValueError, 'shape dimensions must be greater than 0'):
+ sfc.sequence_numeric_column('aaa', shape=[0])
+
+ def test_dtype_is_convertible_to_float(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'dtype must be convertible to float'):
+ sfc.sequence_numeric_column('aaa', dtype=dtypes.string)
+
def test_get_sequence_dense_tensor(self):
sparse_input = sparse_tensor.SparseTensorValue(
# example 0, values [[0.], [1]]