aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linear_optimizer
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-04-30 13:51:34 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-04-30 13:55:05 -0700
commit: b7cb5fa0059b2ef6a40aa15a4d97a01ba2e57d85 (patch)
tree: fe1b4d0c6655a3ea421fbd59c46b82a8f54cf253 /tensorflow/contrib/linear_optimizer
parent: 57b7c7befa52ee4a205536c0552422a750cbcd21 (diff)
Extend SDCAOptimizer to prune negative IDs, which are the default value for out-of-vocabulary (OOV) entries with sparse / categorical tf.feature_column.FeatureColumn.
PiperOrigin-RevId: 194839178
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r-- tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 213c2eced5..12039ecc6f 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -198,6 +198,14 @@ class SDCAOptimizer(object):
example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1])
flat_ids = array_ops.reshape(id_tensor.values, [-1])
+ # Prune invalid IDs (< 0) from the flat_ids, example_ids, and
+ # weight_tensor. These can come from looking up an OOV entry in the
+ # vocabulary (default value being -1).
+ is_id_valid = math_ops.greater_equal(flat_ids, 0)
+ flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid)
+ example_ids = array_ops.boolean_mask(example_ids, is_id_valid)
+ weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid)
+
projection_length = math_ops.reduce_max(flat_ids) + 1
# project ids based on example ids so that we can dedup ids that
# occur multiple times for a single example.