diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-22 11:12:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-22 11:14:40 -0700 |
commit | 1a6752dddf387d280a6a13c2dc7e2bebf69dab2f (patch) | |
tree | 7c88d9d18b4baade07f96156dd6add0927275fc7 /tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | |
parent | 6fa811a94f3da0c49d69db9b15ea424f84a6431f (diff) |
Adds remaining validations in sequence_numeric_column.
PiperOrigin-RevId: 190094883
Diffstat (limited to 'tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py')
-rw-r--r-- | tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py index b64f086376..88f5d53516 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py @@ -662,6 +662,32 @@ class SequenceIndicatorColumnTest(test.TestCase): class SequenceNumericColumnTest(test.TestCase): + def test_defaults(self): + a = sfc.sequence_numeric_column('aaa') + self.assertEqual('aaa', a.key) + self.assertEqual('aaa', a.name) + self.assertEqual('aaa', a._var_scope_name) + self.assertEqual((1,), a.shape) + self.assertEqual(0., a.default_value) + self.assertEqual(dtypes.float32, a.dtype) + + def test_shape_saved_as_tuple(self): + a = sfc.sequence_numeric_column('aaa', shape=[1, 2]) + self.assertEqual((1, 2), a.shape) + + def test_shape_must_be_positive_integer(self): + with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'): + sfc.sequence_numeric_column('aaa', shape=[1.0]) + + with self.assertRaisesRegexp( + ValueError, 'shape dimensions must be greater than 0'): + sfc.sequence_numeric_column('aaa', shape=[0]) + + def test_dtype_is_convertible_to_float(self): + with self.assertRaisesRegexp( + ValueError, 'dtype must be convertible to float'): + sfc.sequence_numeric_column('aaa', dtype=dtypes.string) + def test_get_sequence_dense_tensor(self): sparse_input = sparse_tensor.SparseTensorValue( # example 0, values [[0.], [1]] |