aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2016-07-11 08:15:39 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2016-07-11 09:32:43 -0700
commit a9e8768274350cf5308c42a95e035addf1cfed2e (patch)
tree 5f04b108f5cb7e6941e17c15863b9df93c239562
parent cd2162b3ea581e4dcda58303f61ec088e835ed72 (diff)
Support weighted sparse column in sdca optimizer.
Change: 127093435
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/linear_test.py | 27
-rw-r--r-- tensorflow/contrib/learn/python/learn/estimators/sdca_optimizer.py | 11
2 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index b9e87fbacf..1215ce3728 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -203,6 +203,33 @@ class LinearClassifierTest(tf.test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=2)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerWeightedSparseFeatures(self):
+ """LinearClassifier with SDCAOptimizer and weighted sparse features."""
+
+ def input_fn():
+ return {
+ 'example_id': tf.constant(['1', '2', '3']),
+ 'price': tf.SparseTensor(values=[2., 3., 1.],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 5]),
+ 'country': tf.SparseTensor(values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ shape=[3, 5])
+ }, tf.constant([[1], [0], [1]])
+
+ country = tf.contrib.layers.sparse_column_with_hash_bucket(
+ 'country', hash_bucket_size=5)
+ country_weighted_by_price = tf.contrib.layers.weighted_sparse_column(
+ country, 'price')
+ sdca_optimizer = tf.contrib.learn.SDCAOptimizer(
+ example_id_column='example_id')
+ classifier = tf.contrib.learn.LinearClassifier(
+ feature_columns=[country_weighted_by_price],
+ optimizer=sdca_optimizer)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=2)
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testSdcaOptimizerCrossedFeatures(self):
"""Tests LinearClasssifier with SDCAOptimizer and crossed features."""
diff --git a/tensorflow/contrib/learn/python/learn/estimators/sdca_optimizer.py b/tensorflow/contrib/learn/python/learn/estimators/sdca_optimizer.py
index 093c5d9875..948645d0e6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/sdca_optimizer.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/sdca_optimizer.py
@@ -124,6 +124,17 @@ class SDCAOptimizer(object):
column.length)
sparse_features.append(math_ops.to_float(sparse_features_tensor))
sparse_features_weights.append(columns_to_variables[column][0])
+ elif isinstance(
+ column,
+ layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access
+ id_tensor = column.id_tensor(transformed_tensor)
+ weight_tensor = column.weight_tensor(transformed_tensor)
+ sparse_features_tensor = sparse_ops.sparse_merge(
+ id_tensor, weight_tensor, column.length,
+ name="{}_sparse_merge".format(column.name))
+ sparse_features.append(math_ops.to_float(
+ sparse_features_tensor, name="{}_to_float".format(column.name)))
+ sparse_features_weights.append(columns_to_variables[column][0])
else:
raise ValueError("SDCAOptimizer does not support column type %s." %
type(column).__name__)