about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-09 09:30:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-09 10:51:28 -0700
commit6173228eb19b8a41a71e9d4d046238c6ee990351 (patch)
treebf36765583c3434471baef651272af14391ced25
parentb0ce8deae4f8b0b24c8d8e18c4f62c3b1927f9d8 (diff)
Support integer sparse_column_with_hash_bucket.
Change: 132689168
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py34
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py29
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py15
3 files changed, 68 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index c8f316650a..fa257a38fb 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -443,31 +443,45 @@ def sparse_column_with_integerized_feature(column_name,
class _SparseColumnHashed(_SparseColumn):
"""See `sparse_column_with_hash_bucket`."""
- def __new__(cls, column_name, hash_bucket_size, combiner="sum"):
+ def __new__(cls,
+ column_name,
+ hash_bucket_size,
+ combiner="sum",
+ dtype=dtypes.string):
+
+ if dtype != dtypes.string and not dtype.is_integer:
+ raise ValueError("dtype must be string or integer. "
+ "dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnHashed, cls).__new__(
cls,
column_name,
bucket_size=hash_bucket_size,
combiner=combiner,
- dtype=dtypes.string)
+ dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
+ sparse_tensor = columns_to_tensors[self.name]
+ if self.dtype.is_integer:
+ sparse_values = string_ops.as_string(sparse_tensor.values)
+ else:
+ sparse_values = sparse_tensor.values
+
sparse_id_values = string_ops.string_to_hash_bucket_fast(
- columns_to_tensors[self.name].values, self.bucket_size, name="lookup")
+ sparse_values, self.bucket_size, name="lookup")
columns_to_tensors[self] = ops.SparseTensor(
- columns_to_tensors[self.name].indices, sparse_id_values,
- columns_to_tensors[self.name].shape)
+ sparse_tensor.indices, sparse_id_values, sparse_tensor.shape)
def sparse_column_with_hash_bucket(column_name,
hash_bucket_size,
- combiner="sum"):
+ combiner="sum",
+ dtype=dtypes.string):
"""Creates a _SparseColumn with hashed bucket configuration.
- Use this when your sparse features are in string format, but you don't have a
- vocab file that maps each string to an integer ID.
+ Use this when your sparse features are in string or integer format, but you
+ don't have a vocab file that maps each value to an integer ID.
output_id = Hash(input_feature_string) % bucket_size
Args:
@@ -480,14 +494,16 @@ def sparse_column_with_hash_bucket(column_name,
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
For more information: `tf.embedding_lookup_sparse`.
+ dtype: The type of features. Only string and integer types are supported.
Returns:
A _SparseColumn with hashed bucket configuration
Raises:
ValueError: hash_bucket_size is not greater than 2.
+ ValueError: dtype is neither string nor integer.
"""
- return _SparseColumnHashed(column_name, hash_bucket_size, combiner)
+ return _SparseColumnHashed(column_name, hash_bucket_size, combiner, dtype)
class _SparseColumnKeys(_SparseColumn):
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index 519a769d34..f79327cc9b 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -82,6 +82,21 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output.indices.eval(), wire_tensor.indices.eval())
self.assertAllEqual(output.shape.eval(), wire_tensor.shape.eval())
+ def testSparseIntColumnWithHashBucket(self):
+ """Tests a sparse column with int values."""
+ hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket(
+ "wire", 10, dtype=tf.int64)
+ wire_tensor = tf.SparseTensor(values=[101, 201, 301],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ shape=[2, 2])
+ features = {"wire": wire_tensor}
+ output = feature_column_ops._Transformer(features).transform(hashed_sparse)
+ with self.test_session():
+ self.assertEqual(output.values.dtype, tf.int64)
+ self.assertTrue(all(x < 10 and x >= 0 for x in output.values.eval()))
+ self.assertAllEqual(output.indices.eval(), wire_tensor.indices.eval())
+ self.assertAllEqual(output.shape.eval(), wire_tensor.shape.eval())
+
def testEmbeddingColumn(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
@@ -721,6 +736,20 @@ class WeightedSumTest(tf.test.TestCase):
tf.initialize_all_variables().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ def testSparseIntColumn(self):
+ """Tests a sparse column with int values."""
+ hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket(
+ "wire", 10, dtype=tf.int64)
+ wire_tensor = tf.SparseTensor(values=[101, 201, 301],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ shape=[2, 2])
+ features = {"wire": wire_tensor}
+ logits, _, _ = tf.contrib.layers.weighted_sum_from_feature_columns(
+ features, [hashed_sparse], num_outputs=5)
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ self.assertAllEqual(logits.eval().shape, [2, 5])
+
def testWeightedSparseColumn(self):
ids = tf.contrib.layers.sparse_column_with_keys(
"ids", ["marlo", "omar", "stringer"])
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index cf31acfd86..d43f0b2353 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -32,10 +32,23 @@ class FeatureColumnTest(tf.test.TestCase):
with self.assertRaises(AttributeError):
a.column_name = "bbb"
- def testSparseColumn(self):
+ def testSparseColumnWithHashBucket(self):
a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
hash_bucket_size=100)
self.assertEqual(a.name, "aaa")
+ self.assertEqual(a.dtype, tf.string)
+
+ a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
+ hash_bucket_size=100,
+ dtype=tf.int64)
+ self.assertEqual(a.name, "aaa")
+ self.assertEqual(a.dtype, tf.int64)
+
+ with self.assertRaisesRegexp(ValueError,
+ "dtype must be string or integer"):
+ a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
+ hash_bucket_size=100,
+ dtype=tf.float32)
def testWeightedSparseColumn(self):
ids = tf.contrib.layers.sparse_column_with_keys(