aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-10 15:23:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 15:26:04 -0700
commit0172f3b5b86ccdf32366259a31266a988a9445d5 (patch)
treeaf575720f0b674f0e92c3df397ac23289ae403e9 /tensorflow/python/feature_column
parent99e198185d3a4a8bb089102b71b9fc3920427887 (diff)
Allow negative feature values in computation for `sum` combiner.
PiperOrigin-RevId: 192355950
Diffstat (limited to 'tensorflow/python/feature_column')
-rw-r--r--tensorflow/python/feature_column/feature_column.py15
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py57
2 files changed, 56 insertions, 16 deletions
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 7a104fa4ac..f9201a4794 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -3148,6 +3148,9 @@ def _safe_embedding_lookup_sparse(embedding_weights,
# Prune invalid ids and weights.
sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
+ if combiner != 'sum':
+ sparse_ids, sparse_weights = _prune_invalid_weights(
+ sparse_ids, sparse_weights)
# Fill in dummy values for empty features, if necessary.
sparse_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(sparse_ids,
@@ -3196,13 +3199,23 @@ def _prune_invalid_ids(sparse_ids, sparse_weights):
is_id_valid = math_ops.greater_equal(sparse_ids.values, 0)
if sparse_weights is not None:
is_id_valid = math_ops.logical_and(
- is_id_valid, math_ops.greater(sparse_weights.values, 0))
+ is_id_valid,
+ array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool))
sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_id_valid)
if sparse_weights is not None:
sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_id_valid)
return sparse_ids, sparse_weights
+def _prune_invalid_weights(sparse_ids, sparse_weights):
+ """Prune invalid weights (< 0) from the input ids and weights."""
+ if sparse_weights is not None:
+ is_weights_valid = math_ops.greater(sparse_weights.values, 0)
+ sparse_ids = sparse_ops.sparse_retain(sparse_ids, is_weights_valid)
+ sparse_weights = sparse_ops.sparse_retain(sparse_weights, is_weights_valid)
+ return sparse_ids, sparse_weights
+
+
class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
collections.namedtuple('_IndicatorColumn',
['categorical_column'])):
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index 07588af37e..62718db0e5 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -1511,6 +1511,28 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
+ def test_sparse_combiner_with_negative_weights(self):
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
+
+ with ops.Graph().as_default():
+ wire_tensor = sparse_tensor.SparseTensor(
+ values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2])
+ features = {
+ 'wire_cast': wire_tensor,
+ 'weights': constant_op.constant([[1., 1., -1.0]])
+ }
+ predictions = fc.linear_model(
+ features, [wire_cast_weights], sparse_combiner='sum')
+ bias = get_linear_model_bias()
+ wire_cast_var = get_linear_model_column_var(wire_cast)
+ with _initialized_session() as sess:
+ sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
+ sess.run(bias.assign([5.]))
+ self.assertAllClose([[1005.], [-9985.]], predictions.eval())
+
def test_dense_multi_dimension_multi_output(self):
price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
@@ -6164,14 +6186,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,))
+ predictions = get_keras_linear_model_predictions(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
with _initialized_session():
with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
predictions.eval()
@@ -6255,13 +6279,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,))
+ predictions = fc.linear_model(
+ {
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
+ 'values': ((.5,), (1.,))
+ }, (column,),
+ sparse_combiner='mean')
with _initialized_session():
with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
predictions.eval()