diff options
author | 2017-05-05 09:04:16 -0800 | |
---|---|---|
committer | 2017-05-05 10:24:25 -0700 | |
commit | b329dd821e29e64c93b1b9bf38e61871c6cb53da (patch) | |
tree | 5da6ddb59a5d456fa3aa9abe75dcf5af732f7edc /tensorflow | |
parent | dd63839973e407b4e1501d2dd146e54ee30e445b (diff) |
Add `categorical_column_with_identity`.
Change: 155209179
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/feature_column/feature_column.py | 102 | ||||
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 194 |
2 files changed, 296 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index ffdf8868e2..f8855f259e 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -129,6 +129,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list( default_value=default_value) +def categorical_column_with_identity(key, num_buckets, default_value=None): + """A `_CategoricalColumn` that returns identity values. + + Use this when your inputs are integers in the range `[0, num_buckets)`. Values + outside this range will result in `default_value` if specified, otherwise it + will fail. + + Inputs can be either `Tensor` or `SparseTensor`. + ``` + + Args: + key: A unique string identifying the input feature. It is used as the + column name and the dictionary key for feature parsing configs, feature + `Tensor` objects, and feature columns. + num_buckets: Range of inputs and outputs is `[0, num_buckets)`. + default_value: If `None`, this column's graph operations will fail for + out-of-range inputs. Otherwise, this value must be in the range + `[0, num_buckets)`, and will replace out-of-range inputs. + + Returns: + A `_CategoricalColumn` that returns identity values. + + Raises: + ValueError: if `num_buckets` is less than one. + ValueError: if `default_value` is not in range `[0, num_buckets)`. 
+ """ + if num_buckets < 1: + raise ValueError( + 'num_buckets {} < 1, column_name {}'.format(num_buckets, key)) + if (default_value is not None) and ( + (default_value < 0) or (default_value >= num_buckets)): + raise ValueError( + 'default_value {} not in range [0, {}), column_name {}'.format( + default_value, num_buckets, key)) + return _IdentityCategoricalColumn( + key=key, num_buckets=num_buckets, default_value=default_value) + + class _FeatureColumn(object): """Represents a feature column abstraction. @@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn( return _CategoricalColumn.IdWeightPair(inputs.get(self), None) +class _IdentityCategoricalColumn( + _CategoricalColumn, + collections.namedtuple('_IdentityCategoricalColumn', ( + 'key', 'num_buckets', 'default_value' + ))): + + """See `categorical_column_with_identity`.""" + + @property + def name(self): + return self.key + + @property + def _parse_example_config(self): + return {self.key: parsing_ops.VarLenFeature(dtypes.int64)} + + def _transform_feature(self, inputs): + input_tensor = _to_sparse_input(inputs.get(self.key)) + + if not input_tensor.dtype.is_integer: + raise ValueError( + 'Invalid input, not integer. key: {} dtype: {}'.format( + self.key, input_tensor.dtype)) + + values = math_ops.to_int64(input_tensor.values, name='values') + num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets') + zero = math_ops.to_int64(0, name='zero') + if self.default_value is None: + # Fail if values are out-of-range. + assert_less = check_ops.assert_less( + values, num_buckets, data=(values, num_buckets), + name='assert_less_than_num_buckets') + assert_greater = check_ops.assert_greater_equal( + values, zero, data=(values,), + name='assert_greater_or_equal_0') + with ops.control_dependencies((assert_less, assert_greater)): + values = array_ops.identity(values) + else: + # Assign default for out-of-range values. 
+ values = array_ops.where( + math_ops.logical_or( + values < zero, values >= num_buckets, name='out_of_range'), + array_ops.fill( + dims=array_ops.shape(values), + value=math_ops.to_int64(self.default_value), + name='default_values'), + values) + + return sparse_tensor_lib.SparseTensor( + indices=input_tensor.indices, + values=values, + dense_shape=input_tensor.dense_shape) + + @property + def _num_buckets(self): + """Returns number of buckets in this sparse feature.""" + return self.num_buckets + + def _get_sparse_tensors( + self, inputs, weight_collections=None, trainable=None): + return _CategoricalColumn.IdWeightPair(inputs.get(self), None) + + # TODO(zakaria): Move this to embedding_ops and make it public. def _safe_embedding_lookup_sparse(embedding_weights, sparse_ids, diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index 59aa39411f..5201811831 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variable_scope @@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase): self.assertAllClose(((3.,), (1.,)), predictions.eval()) +class IdentityCategoricalColumnTest(test.TestCase): + + def test_constructor(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column._parse_example_config) + # 
pylint: enable=protected-access + + def test_deep_copy(self): + """Tests deepcopy of categorical_column_with_identity.""" + original = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + for column in (original, copy.deepcopy(original)): + self.assertEqual('aaa', column.name) + # pylint: disable=protected-access + self.assertEqual(3, column._num_buckets) + self.assertEqual({ + 'aaa': parsing_ops.VarLenFeature(dtypes.int64) + }, column._parse_example_config) + # pylint: enable=protected-access + + def test_invalid_num_buckets_zero(self): + with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'): + fc.categorical_column_with_identity(key='aaa', num_buckets=0) + + def test_invalid_num_buckets_negative(self): + with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'): + fc.categorical_column_with_identity(key='aaa', num_buckets=-1) + + def test_invalid_default_value_too_small(self): + with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'): + fc.categorical_column_with_identity( + key='aaa', num_buckets=3, default_value=-1) + + def test_invalid_default_value_too_big(self): + with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'): + fc.categorical_column_with_identity( + key='aaa', num_buckets=3, default_value=3) + + def test_invalid_input_dtype(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=('omar', 'stringer', 'marlo'), + dense_shape=(2, 2)) + with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'): + # pylint: disable=protected-access + column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + + def test_get_sparse_tensors(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 1, 0), + dense_shape=(2, 2)) + # 
pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_dense_input(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({ + 'aaa': ((0, -1), (1, 0)) + })) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=np.array((0, 1, 0), dtype=np.int64), + dense_shape=(2, 2)), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_inputs_too_small(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, -1, 0), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + with self.assertRaisesRegexp( + errors.OpError, 'assert_greater_or_equal_0'): + id_weight_pair.id_tensor.eval() + + def test_get_sparse_tensors_with_inputs_too_big(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 99, 0), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + 
fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + with self.assertRaisesRegexp( + errors.OpError, 'assert_less_than_num_buckets'): + id_weight_pair.id_tensor.eval() + + def test_get_sparse_tensors_with_default_value(self): + column = fc.categorical_column_with_identity( + key='aaa', num_buckets=4, default_value=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, -1, 99), + dense_shape=(2, 2)) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=inputs.indices, + values=np.array((1, 3, 3), dtype=np.int64), + dense_shape=inputs.dense_shape), + id_weight_pair.id_tensor.eval()) + + def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self): + column = fc.categorical_column_with_identity( + key='aaa', num_buckets=4, default_value=3) + input_indices = array_ops.placeholder(dtype=dtypes.int64) + input_values = array_ops.placeholder(dtype=dtypes.int32) + input_shape = array_ops.placeholder(dtype=dtypes.int64) + inputs = sparse_tensor.SparseTensorValue( + indices=input_indices, + values=input_values, + dense_shape=input_shape) + # pylint: disable=protected-access + id_weight_pair = column._get_sparse_tensors( + fc._LazyBuilder({'aaa': inputs})) + # pylint: enable=protected-access + self.assertIsNone(id_weight_pair.weight_tensor) + with _initialized_session(): + _assert_sparse_tensor_value( + self, + sparse_tensor.SparseTensorValue( + indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64), + values=np.array((1, 3, 3), dtype=np.int64), + dense_shape=np.array((2, 2), dtype=np.int64)), + id_weight_pair.id_tensor.eval(feed_dict={ + input_indices: ((0, 
0), (1, 0), (1, 1)), + input_values: (1, -1, 99), + input_shape: (2, 2), + })) + + def test_make_linear_model(self): + column = fc.categorical_column_with_identity(key='aaa', num_buckets=3) + self.assertEqual(3, column._num_buckets) + with ops.Graph().as_default(): + predictions = fc.make_linear_model({ + column.name: sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(0, 2, 1), + dense_shape=(2, 2)) + }, (column,)) + bias = get_linear_model_bias() + weight_var = get_linear_model_column_var(column) + with _initialized_session(): + self.assertAllClose((0.,), bias.eval()) + self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval()) + self.assertAllClose(((0.,), (0.,)), predictions.eval()) + weight_var.assign(((1.,), (2.,), (3.,))).eval() + # weight_var[0] = 1 + # weight_var[2] + weight_var[1] = 3+2 = 5 + self.assertAllClose(((1.,), (5.,)), predictions.eval()) + + if __name__ == '__main__': test.main() |