aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2016-11-14 12:51:24 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-11-14 13:04:25 -0800
commit: e34561fc3d10985e74cdda36b49be8742a928a22 (patch)
tree: 89117ef7a2e57f0f8d5468ecf90445e5221321a9
parent: c40efbdd377f157297c37d5f2d3fec45e775fdb0 (diff)
Adding support for dense tensors as input to sparse_column_* methods.
Change: 139108615
-rw-r--r-- tensorflow/contrib/layers/python/layers/feature_column.py 46
-rw-r--r-- tensorflow/contrib/layers/python/layers/feature_column_ops_test.py 110
2 files changed, 146 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 2365a06668..9d1a99e649 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -119,11 +119,13 @@ from __future__ import print_function
import abc
import collections
import math
+
import six
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
+from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.contrib.lookup import lookup_ops as contrib_lookup_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
@@ -400,6 +402,22 @@ class _SparseColumn(_FeatureColumn,
initializer=init_ops.zeros_initializer,
combiner=self.combiner)
+ def _get_input_sparse_tensor(self, columns_to_tensors):
+ """Looks up the input tensor for transformation and sparsifies it if dense."""
+ input_tensor = columns_to_tensors[self.name]
+ if not isinstance(input_tensor, sparse_tensor_py.SparseTensor):
+ # Avoid guessing which values the caller wants dropped: for numeric tensors
+ # use -1 as ignore_value so that valid indices are not excluded. String
+ # tensors use the empty string as ignore_value instead.
+ if input_tensor.dtype == dtypes.string:
+ ignore_value = ""
+ else:
+ ignore_value = -1
+ input_tensor = contrib_sparse_ops.dense_to_sparse_tensor(
+ input_tensor, ignore_value=ignore_value)
+
+ return input_tensor
+
def is_compatible(self, other_column):
"""Check compatability of two sparse columns."""
if self.lookup_config and other_column.lookup_config:
@@ -432,12 +450,12 @@ class _SparseColumnIntegerized(_SparseColumn):
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
- sparse_id_values = math_ops.mod(columns_to_tensors[self.name].values,
- self.bucket_size,
+ input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
+
+ sparse_id_values = math_ops.mod(input_tensor.values, self.bucket_size,
name="mod")
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
- columns_to_tensors[self.name].indices, sparse_id_values,
- columns_to_tensors[self.name].shape)
+ input_tensor.indices, sparse_id_values, input_tensor.shape)
def sparse_column_with_integerized_feature(column_name,
@@ -501,16 +519,17 @@ class _SparseColumnHashed(_SparseColumn):
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
- sparse_tensor = columns_to_tensors[self.name]
+ input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
+
if self.dtype.is_integer:
- sparse_values = string_ops.as_string(sparse_tensor.values)
+ sparse_values = string_ops.as_string(input_tensor.values)
else:
- sparse_values = sparse_tensor.values
+ sparse_values = input_tensor.values
sparse_id_values = string_ops.string_to_hash_bucket_fast(
sparse_values, self.bucket_size, name="lookup")
columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
- sparse_tensor.indices, sparse_id_values, sparse_tensor.shape)
+ input_tensor.indices, sparse_id_values, input_tensor.shape)
def sparse_column_with_hash_bucket(column_name,
@@ -563,8 +582,10 @@ class _SparseColumnKeys(_SparseColumn):
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
+ input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
+
columns_to_tensors[self] = contrib_lookup_ops.string_to_index(
- tensor=columns_to_tensors[self.name],
+ tensor=input_tensor,
mapping=list(self.lookup_config.keys),
default_value=self.lookup_config.default_value,
name="lookup")
@@ -636,9 +657,14 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
"""Inserts a tuple with the id and weight tensors."""
if self.sparse_id_column not in columns_to_tensors:
self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
+
+ weight_tensor = columns_to_tensors[self.weight_column_name]
+ if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
+ # The weight tensor can be a regular Tensor. In such a case, sparsify it.
+ weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
columns_to_tensors[self] = tuple([
columns_to_tensors[self.sparse_id_column],
- columns_to_tensors[self.weight_column_name]
+ weight_tensor
])
def id_tensor(self, input_tensor):
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index e31a9fb7c8..342e5ae117 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -118,6 +118,21 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output[hashed_sparse].shape.eval(),
wire_tensor.shape.eval())
+ def testSparseColumnWithHashBucketWithDenseInputTensor(self):
+ hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
+ wire_tensor = tf.constant([["omar", "stringer"], ["marlo", "rick"]])
+ features = {"wire": wire_tensor}
+ output = feature_column_ops._Transformer(features).transform(hashed_sparse)
+
+ with self.test_session():
+ # While the input is a dense Tensor, the output should be a SparseTensor.
+ self.assertIsInstance(output, tf.SparseTensor)
+ self.assertEqual(output.values.dtype, tf.int64)
+ self.assertTrue(all(x < 10 and x >= 0 for x in output.values.eval()))
+ self.assertAllEqual(output.indices.eval(),
+ [[0, 0], [0, 1], [1, 0], [1, 1]])
+ self.assertAllEqual(output.shape.eval(), [2, 2])
+
def testEmbeddingColumn(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = tf.SparseTensor(values=["omar", "stringer", "marlo"],
@@ -160,6 +175,24 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output[keys_sparse].shape.eval(),
wire_tensor.shape.eval())
+ def testSparseColumnWithKeysWithDenseInputTensor(self):
+ keys_sparse = tf.contrib.layers.sparse_column_with_keys(
+ "wire", ["marlo", "omar", "stringer", "rick"])
+ wire_tensor = tf.constant([["omar", "stringer"], ["marlo", "rick"]])
+
+ features = {"wire": wire_tensor}
+ output = feature_column_ops._Transformer(features).transform(keys_sparse)
+
+ with self.test_session():
+ tf.initialize_all_tables().run()
+ # While the input is a dense Tensor, the output should be a SparseTensor.
+ self.assertIsInstance(output, tf.SparseTensor)
+ self.assertEqual(output.dtype, tf.int64)
+ self.assertAllEqual(output.values.eval(), [1, 2, 0, 3])
+ self.assertAllEqual(output.indices.eval(),
+ [[0, 0], [0, 1], [1, 0], [1, 1]])
+ self.assertAllEqual(output.shape.eval(), [2, 2])
+
def testSparseColumnWithHashBucket_IsIntegerized(self):
hashed_sparse = tf.contrib.layers.sparse_column_with_integerized_feature(
"wire", 10)
@@ -181,6 +214,24 @@ class TransformerTest(tf.test.TestCase):
self.assertAllEqual(output[hashed_sparse].shape.eval(),
wire_tensor.shape.eval())
+ def testSparseColumnWithHashBucketWithDenseInputTensor_IsIntegerized(self):
+ hashed_sparse = tf.contrib.layers.sparse_column_with_integerized_feature(
+ "wire", 10)
+ # Dense counterpart of SparseTensor(values=[100, 1, 25],
+ # indices=[[0, 0], [1, 0], [1, 1]], shape=[2, 2]). The 0 at [0, 1] is kept
+ # as a real value, because only -1 is ignored for numeric dense input.
+ wire_tensor = tf.constant([[100, 0], [1, 25]])
+ features = {"wire": wire_tensor}
+ output = feature_column_ops._Transformer(features).transform(hashed_sparse)
+ with self.test_session():
+ # While the input is a dense Tensor, the output should be a SparseTensor.
+ self.assertIsInstance(output, tf.SparseTensor)
+ self.assertEqual(output.values.dtype, tf.int32)
+ self.assertTrue(all(x < 10 and x >= 0 for x in output.values.eval()))
+ self.assertAllEqual(output.indices.eval(),
+ [[0, 0], [0, 1], [1, 0], [1, 1]])
+ self.assertAllEqual(output.shape.eval(), [2, 2])
+
def testWeightedSparseColumn(self):
ids = tf.contrib.layers.sparse_column_with_keys(
"ids", ["marlo", "omar", "stringer"])
@@ -1130,6 +1181,16 @@ class WeightedSumTest(tf.test.TestCase):
tf.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ def testSparseColumnWithDenseInputTensor(self):
+ hashed_sparse = tf.contrib.layers.sparse_column_with_hash_bucket("wire", 10)
+ wire_tensor = tf.constant([["omar", "stringer"], ["marlo", "rick"]])
+ features = {"wire": wire_tensor}
+ logits, _, _ = tf.contrib.layers.weighted_sum_from_feature_columns(
+ features, [hashed_sparse], num_outputs=5)
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ self.assertAllEqual(logits.eval().shape, [2, 5])
+
def testWeightedSparseColumn(self):
ids = tf.contrib.layers.sparse_column_with_keys(
"ids", ["marlo", "omar", "stringer"])
@@ -1149,6 +1210,23 @@ class WeightedSumTest(tf.test.TestCase):
tf.initialize_all_tables().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
+ def testWeightedSparseColumnWithDenseInputTensor(self):
+ ids = tf.contrib.layers.sparse_column_with_keys(
+ "ids", ["marlo", "omar", "stringer", "rick"])
+ ids_tensor = tf.constant([["omar", "stringer"], ["marlo", "rick"]])
+ weighted_ids = tf.contrib.layers.weighted_sparse_column(ids, "weights")
+ weights_tensor = tf.constant([[10.0, 20.0], [30.0, 40.0]])
+
+ features = {"ids": ids_tensor,
+ "weights": weights_tensor}
+ logits, _, _ = tf.contrib.layers.weighted_sum_from_feature_columns(
+ features, [weighted_ids], num_outputs=5)
+
+ with self.test_session():
+ tf.initialize_all_variables().run()
+ tf.initialize_all_tables().run()
+ self.assertAllEqual(logits.eval().shape, [2, 5])
+
def testCrossedColumn(self):
a = tf.contrib.layers.sparse_column_with_hash_bucket("aaa",
hash_bucket_size=100)
@@ -1737,6 +1815,38 @@ class WeightedSumTest(tf.test.TestCase):
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
+ def testIntegerizedColumnWithDenseInputTensor(self):
+ product = tf.contrib.layers.sparse_column_with_integerized_feature(
+ "product", bucket_size=5)
+ with tf.Graph().as_default():
+ features = {"product": tf.constant([[0], [4], [2]])}
+ output, column_to_variable, _ = (
+ tf.contrib.layers.weighted_sum_from_feature_columns(features,
+ [product],
+ num_outputs=1))
+ with self.test_session() as sess:
+ tf.initialize_all_variables().run()
+ tf.initialize_all_tables().run()
+ product_weights = column_to_variable[product][0]
+ sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
+ self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
+
+ def testIntegerizedColumnWithDenseInputTensor2(self):
+ product = tf.contrib.layers.sparse_column_with_integerized_feature(
+ "product", bucket_size=5)
+ with tf.Graph().as_default():
+ features = {"product": tf.constant([[0, 4], [2, 3]])}
+ output, column_to_variable, _ = (
+ tf.contrib.layers.weighted_sum_from_feature_columns(features,
+ [product],
+ num_outputs=1))
+ with self.test_session() as sess:
+ tf.initialize_all_variables().run()
+ tf.initialize_all_tables().run()
+ product_weights = column_to_variable[product][0]
+ sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
+ self.assertAllClose(output.eval(), [[0.6], [0.7]])
+
def testIntegerizedColumnWithInvalidId(self):
product = tf.contrib.layers.sparse_column_with_integerized_feature(
"product", bucket_size=5)