aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-16 15:23:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 15:32:49 -0700
commite45de7cbe36cd5f72a9abb5c828bde9dac700bd8 (patch)
tree5e095346bd2ada50208c2a39753abed354931185 /tensorflow/contrib/boosted_trees
parente2b990f194308a8d5bb1ca84f51cf87e63e2382a (diff)
Adding weighted categorical feature column support.
PiperOrigin-RevId: 209058978
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py61
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py15
2 files changed, 76 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 68d710d713..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -16,7 +16,10 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
import tempfile
+import numpy as np
+
from tensorflow.contrib.boosted_trees.estimator_batch import estimator
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column
@@ -26,6 +29,7 @@ from tensorflow.python.feature_column import feature_column_lib as core_feature_
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
@@ -473,6 +477,63 @@ class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_multiclass_train_input_fn, steps=1)
classifier.predict(input_fn=_eval_input_fn)
+ def testWeightedCategoricalColumn(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ feature_columns = [
+ core_feature_column.weighted_categorical_column(
+ categorical_column=core_feature_column.
+ categorical_column_with_vocabulary_list(
+ key="word", vocabulary_list=["the", "cat", "dog"]),
+ weight_feature_key="weight")
+ ]
+
+ labels = np.array([[1], [1], [0], [0.]], dtype=np.float32)
+
+ def _make_input_fn():
+
+ def _input_fn():
+ features_dict = {}
+ # Sparse "word" feature with one row per example (4 examples):
+ # example 0: "the", "cat"
+ # example 1: "dog"
+ # example 2: (no words)
+ # example 3: "the"
+ # The parallel "weight" tensor assigns "the" -> 1, "cat" -> 5, "dog" -> 6.
+ features_dict["word"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=constant_op.constant(
+ ["the", "cat", "dog", "the"], dtype=dtypes.string),
+ dense_shape=[4, 3])
+ features_dict["weight"] = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0], [3, 0]],
+ values=[1., 5., 6., 1.],
+ dense_shape=[4, 3])
+ return features_dict, labels
+
+ return _input_fn
+
+ est = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=feature_columns)
+
+ input_fn = _make_input_fn()
+ est.train(input_fn=input_fn, steps=100)
+ est.evaluate(input_fn=input_fn, steps=1)
+ est.predict(input_fn=input_fn)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 20ff48c360..2f75d8aa99 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -218,6 +218,21 @@ def extract_features(features, feature_columns, use_core_columns):
sparse_int_shapes = []
for key in sorted(features.keys()):
tensor = features[key]
+ # TODO(nponomareva): consider iterating over feature columns instead.
+ if isinstance(tensor, tuple):
+ # Weighted categorical feature.
+ categorical_tensor = tensor[0]
+ weight_tensor = tensor[1]
+
+ shape = categorical_tensor.dense_shape
+ indices = array_ops.concat([
+ array_ops.slice(categorical_tensor.indices, [0, 0], [-1, 1]),
+ array_ops.expand_dims(
+ math_ops.to_int64(categorical_tensor.values), -1)
+ ], 1)
+ tensor = sparse_tensor.SparseTensor(
+ indices=indices, values=weight_tensor.values, dense_shape=shape)
+
if isinstance(tensor, sparse_tensor.SparseTensor):
if tensor.values.dtype == dtypes.float32:
sparse_float_names.append(key)