aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-05 09:04:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-05 10:24:25 -0700
commitb329dd821e29e64c93b1b9bf38e61871c6cb53da (patch)
tree5da6ddb59a5d456fa3aa9abe75dcf5af732f7edc /tensorflow
parentdd63839973e407b4e1501d2dd146e54ee30e445b (diff)
Add `categorical_column_with_identity`.
Change: 155209179
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/feature_column/feature_column.py102
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py194
2 files changed, 296 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index ffdf8868e2..f8855f259e 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -129,6 +129,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list(
default_value=default_value)
+def categorical_column_with_identity(key, num_buckets, default_value=None):
+ """A `_CategoricalColumn` that returns identity values.
+
+ Use this when your inputs are integers in the range `[0, num_buckets)`. Values
+ outside this range will result in `default_value` if specified, otherwise it
+ will fail.
+
+ Inputs can be either `Tensor` or `SparseTensor`.
+ ```
+
+ Args:
+ key: A unique string identifying the input feature. It is used as the
+ column name and the dictionary key for feature parsing configs, feature
+ `Tensor` objects, and feature columns.
+ num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
+ default_value: If `None`, this column's graph operations will fail for
+ out-of-range inputs. Otherwise, this value must be in the range
+ `[0, num_buckets)`, and will replace inputs in that range.
+
+ Returns:
+ A `_CategoricalColumn` that returns identity values.
+
+ Raises:
+ ValueError: if `num_buckets` is less than one.
+ ValueError: if `default_value` is not in range `[0, num_buckets)`.
+ """
+ if num_buckets < 1:
+ raise ValueError(
+ 'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
+ if (default_value is not None) and (
+ (default_value < 0) or (default_value >= num_buckets)):
+ raise ValueError(
+ 'default_value {} not in range [0, {}), column_name {}'.format(
+ default_value, num_buckets, key))
+ return _IdentityCategoricalColumn(
+ key=key, num_buckets=num_buckets, default_value=default_value)
+
+
class _FeatureColumn(object):
"""Represents a feature column abstraction.
@@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn(
return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+class _IdentityCategoricalColumn(
+ _CategoricalColumn,
+ collections.namedtuple('_IdentityCategoricalColumn', (
+ 'key', 'num_buckets', 'default_value'
+ ))):
+
+ """See `categorical_column_with_identity`."""
+
+ @property
+ def name(self):
+ return self.key
+
+ @property
+ def _parse_example_config(self):
+ return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
+
+ def _transform_feature(self, inputs):
+ input_tensor = _to_sparse_input(inputs.get(self.key))
+
+ if not input_tensor.dtype.is_integer:
+ raise ValueError(
+ 'Invalid input, not integer. key: {} dtype: {}'.format(
+ self.key, input_tensor.dtype))
+
+ values = math_ops.to_int64(input_tensor.values, name='values')
+ num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
+ zero = math_ops.to_int64(0, name='zero')
+ if self.default_value is None:
+ # Fail if values are out-of-range.
+ assert_less = check_ops.assert_less(
+ values, num_buckets, data=(values, num_buckets),
+ name='assert_less_than_num_buckets')
+ assert_greater = check_ops.assert_greater_equal(
+ values, zero, data=(values,),
+ name='assert_greater_or_equal_0')
+ with ops.control_dependencies((assert_less, assert_greater)):
+ values = array_ops.identity(values)
+ else:
+ # Assign default for out-of-range values.
+ values = array_ops.where(
+ math_ops.logical_or(
+ values < zero, values >= num_buckets, name='out_of_range'),
+ array_ops.fill(
+ dims=array_ops.shape(values),
+ value=math_ops.to_int64(self.default_value),
+ name='default_values'),
+ values)
+
+ return sparse_tensor_lib.SparseTensor(
+ indices=input_tensor.indices,
+ values=values,
+ dense_shape=input_tensor.dense_shape)
+
+ @property
+ def _num_buckets(self):
+ """Returns number of buckets in this sparse feature."""
+ return self.num_buckets
+
+ def _get_sparse_tensors(
+ self, inputs, weight_collections=None, trainable=None):
+ return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
sparse_ids,
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 59aa39411f..5201811831 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
@@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
self.assertAllClose(((3.,), (1.,)), predictions.eval())
+class IdentityCategoricalColumnTest(test.TestCase):
+
+ def test_constructor(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_deep_copy(self):
+ """Tests deepcopy of categorical_column_with_hash_bucket."""
+ original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ for column in (original, copy.deepcopy(original)):
+ self.assertEqual('aaa', column.name)
+ # pylint: disable=protected-access
+ self.assertEqual(3, column._num_buckets)
+ self.assertEqual({
+ 'aaa': parsing_ops.VarLenFeature(dtypes.int64)
+ }, column._parse_example_config)
+ # pylint: enable=protected-access
+
+ def test_invalid_num_buckets_zero(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=0)
+
+ def test_invalid_num_buckets_negative(self):
+ with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
+ fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
+
+ def test_invalid_default_value_too_small(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=-1)
+
+ def test_invalid_default_value_too_big(self):
+ with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
+ fc.categorical_column_with_identity(
+ key='aaa', num_buckets=3, default_value=3)
+
+ def test_invalid_input_dtype(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2))
+ with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
+ # pylint: disable=protected-access
+ column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+
+ def test_get_sparse_tensors(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 1, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_dense_input(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
+ 'aaa': ((0, -1), (1, 0))
+ }))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=np.array((0, 1, 0), dtype=np.int64),
+ dense_shape=(2, 2)),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_inputs_too_small(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_greater_or_equal_0'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_inputs_too_big(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 99, 0),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ with self.assertRaisesRegexp(
+ errors.OpError, 'assert_less_than_num_buckets'):
+ id_weight_pair.id_tensor.eval()
+
+ def test_get_sparse_tensors_with_default_value(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, -1, 99),
+ dense_shape=(2, 2))
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=inputs.indices,
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=inputs.dense_shape),
+ id_weight_pair.id_tensor.eval())
+
+ def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
+ column = fc.categorical_column_with_identity(
+ key='aaa', num_buckets=4, default_value=3)
+ input_indices = array_ops.placeholder(dtype=dtypes.int64)
+ input_values = array_ops.placeholder(dtype=dtypes.int32)
+ input_shape = array_ops.placeholder(dtype=dtypes.int64)
+ inputs = sparse_tensor.SparseTensorValue(
+ indices=input_indices,
+ values=input_values,
+ dense_shape=input_shape)
+ # pylint: disable=protected-access
+ id_weight_pair = column._get_sparse_tensors(
+ fc._LazyBuilder({'aaa': inputs}))
+ # pylint: enable=protected-access
+ self.assertIsNone(id_weight_pair.weight_tensor)
+ with _initialized_session():
+ _assert_sparse_tensor_value(
+ self,
+ sparse_tensor.SparseTensorValue(
+ indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
+ values=np.array((1, 3, 3), dtype=np.int64),
+ dense_shape=np.array((2, 2), dtype=np.int64)),
+ id_weight_pair.id_tensor.eval(feed_dict={
+ input_indices: ((0, 0), (1, 0), (1, 1)),
+ input_values: (1, -1, 99),
+ input_shape: (2, 2),
+ }))
+
+ def test_make_linear_model(self):
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
+ self.assertEqual(3, column._num_buckets)
+ with ops.Graph().as_default():
+ predictions = fc.make_linear_model({
+ column.name: sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2))
+ }, (column,))
+ bias = get_linear_model_bias()
+ weight_var = get_linear_model_column_var(column)
+ with _initialized_session():
+ self.assertAllClose((0.,), bias.eval())
+ self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
+ self.assertAllClose(((0.,), (0.,)), predictions.eval())
+ weight_var.assign(((1.,), (2.,), (3.,))).eval()
+ # weight_var[0] = 1
+ # weight_var[2] + weight_var[1] = 3+2 = 5
+ self.assertAllClose(((1.,), (5.,)), predictions.eval())
+
+
if __name__ == '__main__':
test.main()