aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/learn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-30 13:51:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-30 13:55:05 -0700
commitb7cb5fa0059b2ef6a40aa15a4d97a01ba2e57d85 (patch)
treefe1b4d0c6655a3ea421fbd59c46b82a8f54cf253 /tensorflow/contrib/learn
parent57b7c7befa52ee4a205536c0552422a750cbcd21 (diff)
Extend SDCAOptimizer functionality to prune negative indices (the default value for OOV with tf.feature_column.FeatureColumn, sparse / categorical).
PiperOrigin-RevId: 194839178
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear_test.py32
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
index d3bb0fda57..0a863f0e20 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py
@@ -863,6 +863,38 @@ class LinearClassifierTest(test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=1)
self.assertGreater(scores['accuracy'], 0.9)
+ def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self):
+ """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs)."""
+
+ def input_fn():
+ return {
+ 'example_id':
+ constant_op.constant(['1', '2', '3']),
+ 'price':
+ sparse_tensor.SparseTensor(
+ values=[2., 3., 1.],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5]),
+ 'country':
+ sparse_tensor.SparseTensor(
+ # 'GB' is out of the vocabulary.
+ values=['IT', 'US', 'GB'],
+ indices=[[0, 0], [1, 0], [2, 0]],
+ dense_shape=[3, 5])
+ }, constant_op.constant([[1], [0], [1]])
+
+ country = feature_column_lib.sparse_column_with_keys(
+ 'country', keys=['US', 'CA', 'MK', 'IT', 'CN'])
+ country_weighted_by_price = feature_column_lib.weighted_sparse_column(
+ country, 'price')
+ sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer(
+ example_id_column='example_id')
+ classifier = linear.LinearClassifier(
+ feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer)
+ classifier.fit(input_fn=input_fn, steps=50)
+ scores = classifier.evaluate(input_fn=input_fn, steps=1)
+ self.assertGreater(scores['accuracy'], 0.9)
+
def testSdcaOptimizerCrossedFeatures(self):
"""Tests LinearClassifier with SDCAOptimizer and crossed features."""