aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/feature_column/feature_column_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_test.py')
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index b14ec73ba2..3057776391 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -3206,6 +3206,20 @@ class IndicatorColumnTest(test.TestCase):
with _initialized_session():
self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval())
+ def test_transform_with_weighted_column(self):
+ # Github issue 12557
+ ids = fc.categorical_column_with_vocabulary_list(
+ key='ids', vocabulary_list=('a', 'b', 'c'))
+ weights = fc.weighted_categorical_column(ids, 'weights')
+ indicator = fc.indicator_column(weights)
+ features = {
+ 'ids': constant_op.constant(['c', 'b', 'a'], shape=(1, 3)),
+ 'weights': constant_op.constant([2., 4., 6.], shape=(1, 3))
+ }
+ indicator_tensor = _transform_features(features, [indicator])[indicator]
+ with _initialized_session():
+ self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval())
+
def test_linear_model(self):
animal = fc.indicator_column(
fc.categorical_column_with_identity('animal', num_buckets=4))