diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-30 13:51:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 13:55:05 -0700 |
commit | b7cb5fa0059b2ef6a40aa15a4d97a01ba2e57d85 (patch) | |
tree | fe1b4d0c6655a3ea421fbd59c46b82a8f54cf253 /tensorflow/contrib/linear_optimizer | |
parent | 57b7c7befa52ee4a205536c0552422a750cbcd21 (diff) |
Extend SDCAOptimizer functionality to prune negative indices (the default value for OOV with tf.feature_column.FeatureColumn, sparse / categorical).
PiperOrigin-RevId: 194839178
Diffstat (limited to 'tensorflow/contrib/linear_optimizer')
-rw-r--r-- | tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 8 |
1 file changed, 8 insertions, 0 deletions
```diff
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 213c2eced5..12039ecc6f 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -198,6 +198,14 @@ class SDCAOptimizer(object):
           example_ids = array_ops.reshape(id_tensor.indices[:, 0], [-1])
 
           flat_ids = array_ops.reshape(id_tensor.values, [-1])
+          # Prune invalid IDs (< 0) from the flat_ids, example_ids, and
+          # weight_tensor. These can come from looking up an OOV entry in the
+          # vocabulary (default value being -1).
+          is_id_valid = math_ops.greater_equal(flat_ids, 0)
+          flat_ids = array_ops.boolean_mask(flat_ids, is_id_valid)
+          example_ids = array_ops.boolean_mask(example_ids, is_id_valid)
+          weight_tensor = array_ops.boolean_mask(weight_tensor, is_id_valid)
+
           projection_length = math_ops.reduce_max(flat_ids) + 1
           # project ids based on example ids so that we can dedup ids that
           # occur multiple times for a single example.
```