diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-30 13:51:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 13:55:05 -0700 |
commit | b7cb5fa0059b2ef6a40aa15a4d97a01ba2e57d85 (patch) | |
tree | fe1b4d0c6655a3ea421fbd59c46b82a8f54cf253 /tensorflow/contrib/learn | |
parent | 57b7c7befa52ee4a205536c0552422a750cbcd21 (diff) |
Extend SDCAOptimizer functionality to prune negative indices (the default value for OOV with tf.feature_column.FeatureColumn, sparse / categorical).
PiperOrigin-RevId: 194839178
Diffstat (limited to 'tensorflow/contrib/learn')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/linear_test.py | 32 |
1 files changed, 32 insertions, 0 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py index d3bb0fda57..0a863f0e20 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear_test.py @@ -863,6 +863,38 @@ class LinearClassifierTest(test.TestCase): scores = classifier.evaluate(input_fn=input_fn, steps=1) self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerWeightedSparseFeaturesOOVWithNoOOVBuckets(self): + """LinearClassifier with SDCAOptimizer with OOV features (-1 IDs).""" + + def input_fn(): + return { + 'example_id': + constant_op.constant(['1', '2', '3']), + 'price': + sparse_tensor.SparseTensor( + values=[2., 3., 1.], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]), + 'country': + sparse_tensor.SparseTensor( + # 'GB' is out of the vocabulary. + values=['IT', 'US', 'GB'], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 5]) + }, constant_op.constant([[1], [0], [1]]) + + country = feature_column_lib.sparse_column_with_keys( + 'country', keys=['US', 'CA', 'MK', 'IT', 'CN']) + country_weighted_by_price = feature_column_lib.weighted_sparse_column( + country, 'price') + sdca_optimizer = sdca_optimizer_lib.SDCAOptimizer( + example_id_column='example_id') + classifier = linear.LinearClassifier( + feature_columns=[country_weighted_by_price], optimizer=sdca_optimizer) + classifier.fit(input_fn=input_fn, steps=50) + scores = classifier.evaluate(input_fn=input_fn, steps=1) + self.assertGreater(scores['accuracy'], 0.9) + def testSdcaOptimizerCrossedFeatures(self): """Tests LinearClassifier with SDCAOptimizer and crossed features.""" |