diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-16 15:23:41 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 15:32:49 -0700 |
commit | e45de7cbe36cd5f72a9abb5c828bde9dac700bd8 (patch) | |
tree | 5e095346bd2ada50208c2a39753abed354931185 /tensorflow/contrib/boosted_trees | |
parent | e2b990f194308a8d5bb1ca84f51cf87e63e2382a (diff) |
Adding weighted categorical feature column support.
PiperOrigin-RevId: 209058978
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r-- | tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py | 61 | ||||
-rw-r--r-- | tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py | 15 |
2 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py index 68d710d713..c155128c0e 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py @@ -16,7 +16,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + import tempfile +import numpy as np + from tensorflow.contrib.boosted_trees.estimator_batch import estimator from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column @@ -26,6 +29,7 @@ from tensorflow.python.feature_column import feature_column_lib as core_feature_ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops.losses import losses from tensorflow.python.platform import gfile @@ -473,6 +477,63 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase): classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1) classifier.predict(input_fn=_eval_input_fn) + def testWeightedCategoricalColumn(self): + head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + feature_columns = [ + core_feature_column.weighted_categorical_column( + categorical_column=core_feature_column. + categorical_column_with_vocabulary_list( + key="word", vocabulary_list=["the", "cat", "dog"]), + weight_feature_key="weight") + ] + + labels = np.array([[1], [1], [0], [0.]], dtype=np.float32) + + def _make_input_fn(): + + def _input_fn(): + features_dict = {} + # Sparse tensor representing + # example 0: "cat","the" + # examaple 1: "dog" + # example 2: - + # example 3: "the" + # Weights for the words are 5 - cat, 6- dog and 1 -the. + features_dict["word"] = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [3, 0]], + values=constant_op.constant( + ["the", "cat", "dog", "the"], dtype=dtypes.string), + dense_shape=[4, 3]) + features_dict["weight"] = sparse_tensor.SparseTensor( + indices=[[0, 0], [0, 1], [1, 0], [3, 0]], + values=[1., 5., 6., 1.], + dense_shape=[4, 3]) + return features_dict, labels + + return _input_fn + + est = estimator.CoreGradientBoostedDecisionTreeEstimator( + head=head_fn, + learner_config=learner_config, + num_trees=1, + examples_per_layer=3, + model_dir=model_dir, + config=config, + feature_columns=feature_columns) + + input_fn = _make_input_fn() + est.train(input_fn=input_fn, steps=100) + est.evaluate(input_fn=input_fn, steps=1) + est.predict(input_fn=input_fn) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 20ff48c360..2f75d8aa99 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -218,6 +218,21 @@ def extract_features(features, feature_columns, use_core_columns): sparse_int_shapes = [] for key in sorted(features.keys()): tensor = features[key] + # TODO(nponomareva): consider iterating over feature columns instead. + if isinstance(tensor, tuple): + # Weighted categorical feature. + categorical_tensor = tensor[0] + weight_tensor = tensor[1] + + shape = categorical_tensor.dense_shape + indices = array_ops.concat([ + array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]), + array_ops.expand_dims( + math_ops.to_int64(categorical_tensor.values), -1) + ], 1) + tensor = sparse_tensor.SparseTensor( + indices=indices, values=weight_tensor.values, dense_shape=shape) + if isinstance(tensor, sparse_tensor.SparseTensor): if tensor.values.dtype == dtypes.float32: sparse_float_names.append(key) |