diff options
Diffstat (limited to 'tensorflow/python/feature_column/feature_column_test.py')
-rw-r--r-- | tensorflow/python/feature_column/feature_column_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py index b14ec73ba2..3057776391 100644 --- a/tensorflow/python/feature_column/feature_column_test.py +++ b/tensorflow/python/feature_column/feature_column_test.py @@ -3206,6 +3206,20 @@ class IndicatorColumnTest(test.TestCase): with _initialized_session(): self.assertAllEqual([[0, 0, 1], [1, 0, 0]], indicator_tensor.eval()) + def test_transform_with_weighted_column(self): + # Github issue 12557 + ids = fc.categorical_column_with_vocabulary_list( + key='ids', vocabulary_list=('a', 'b', 'c')) + weights = fc.weighted_categorical_column(ids, 'weights') + indicator = fc.indicator_column(weights) + features = { + 'ids': constant_op.constant(['c', 'b', 'a'], shape=(1, 3)), + 'weights': constant_op.constant([2., 4., 6.], shape=(1, 3)) + } + indicator_tensor = _transform_features(features, [indicator])[indicator] + with _initialized_session(): + self.assertAllEqual([[6., 4., 2.]], indicator_tensor.eval()) + def test_linear_model(self): animal = fc.indicator_column( fc.categorical_column_with_identity('animal', num_buckets=4)) |