aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-26 10:45:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-26 10:49:39 -0700
commit0a3155f7fbf56df5e81c7cbf35afd45173359635 (patch)
treeaf1a7ed1e79aceecd03e5fa4f1134fd90d57d866 /tensorflow/contrib/boosted_trees
parent7871a8c13b8998cc1e06ce34fe54cad832a6f78e (diff)
Adding core estimator for a fusion model.
PiperOrigin-RevId: 206183643
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD1
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py265
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py70
3 files changed, 272 insertions, 64 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index ef0e80cd09..f4a375328e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -147,6 +147,7 @@ py_library(
deps = [
":distillation_loss",
":estimator_utils",
+ ":model",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees:model_ops_py",
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index 7eb429b636..dbfa69edcb 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -26,6 +26,7 @@ from __future__ import print_function
import six
from tensorflow.contrib import layers
+from tensorflow.contrib.boosted_trees.estimator_batch import model
from tensorflow.contrib.boosted_trees.estimator_batch import distillation_loss
from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
@@ -34,6 +35,7 @@ from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batc
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
@@ -62,27 +64,29 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram("%s_activation" % tag, value)
-def _dnn_tree_combined_model_fn(features,
- labels,
- mode,
- head,
- dnn_hidden_units,
- dnn_feature_columns,
- tree_learner_config,
- num_trees,
- tree_examples_per_layer,
- config=None,
- dnn_optimizer="Adagrad",
- dnn_activation_fn=nn.relu,
- dnn_dropout=None,
- dnn_input_layer_partitioner=None,
- dnn_input_layer_to_tree=True,
- dnn_steps_to_train=10000,
- predict_with_tree_only=False,
- tree_feature_columns=None,
- tree_center_bias=False,
- dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+def _dnn_tree_combined_model_fn(
+ features,
+ labels,
+ mode,
+ head,
+ dnn_hidden_units,
+ dnn_feature_columns,
+ tree_learner_config,
+ num_trees,
+ tree_examples_per_layer,
+ config=None,
+ dnn_optimizer="Adagrad",
+ dnn_activation_fn=nn.relu,
+ dnn_dropout=None,
+ dnn_input_layer_partitioner=None,
+ dnn_input_layer_to_tree=True,
+ dnn_steps_to_train=10000,
+ predict_with_tree_only=False,
+ tree_feature_columns=None,
+ tree_center_bias=False,
+ dnn_to_tree_distillation_param=None,
+ use_core_versions=False,
+ output_type=model.ModelBuilderOutputType.MODEL_FN_OPS):
"""DNN and GBDT combined model_fn.
Args:
@@ -156,6 +160,10 @@ def _dnn_tree_combined_model_fn(features,
partitioned_variables.min_max_variable_partitioner(
max_partitions=config.num_ps_replicas, min_slice_size=64 << 20))
+ if (output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC and
+ not use_core_versions):
+ raise ValueError("You must use core versions with Estimator Spec")
+
with variable_scope.variable_scope(
dnn_parent_scope,
values=tuple(six.itervalues(features)),
@@ -235,7 +243,8 @@ def _dnn_tree_combined_model_fn(features,
learner_config=tree_learner_config,
feature_columns=tree_feature_columns,
logits_dimension=head.logits_dimension,
- features=tree_features)
+ features=tree_features,
+ use_core_columns=use_core_versions)
with ops.name_scope("gbdt"):
predictions_dict = gbdt_model.predict(mode)
@@ -284,63 +293,96 @@ def _dnn_tree_combined_model_fn(features,
del loss
return control_flow_ops.no_op()
- if use_core_versions:
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_no_train_op_fn,
- logits=tree_train_logits)
- dnn_train_op = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_dnn_train_op_fn,
- logits=dnn_logits)
- dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
- dnn_train_op).train_op
+ if tree_center_bias:
+ num_trees += 1
+ finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- tree_train_op = head.create_estimator_spec(
- features=tree_features,
- mode=mode,
- labels=labels,
- train_op_fn=_tree_train_op_fn,
- logits=tree_train_logits)
- tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
- tree_train_op).train_op
+ if output_type == model.ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_versions:
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_no_train_op_fn,
+ logits=tree_train_logits)
+ dnn_train_op = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_dnn_train_op_fn,
+ logits=dnn_logits)
+ dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ dnn_train_op).train_op
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
+ tree_train_op = head.create_estimator_spec(
+ features=tree_features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_tree_train_op_fn,
+ logits=tree_train_logits)
+ tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ tree_train_op).train_op
+
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_no_train_op_fn,
+ logits=tree_train_logits)
+ dnn_train_op = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_dnn_train_op_fn,
+ logits=dnn_logits).train_op
+ tree_train_op = head.create_model_fn_ops(
+ features=tree_features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_tree_train_op_fn,
+ logits=tree_train_logits).train_op
+
+ # Add the hooks
+ model_fn_ops.training_hooks.extend([
+ trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
+ tree_train_op),
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees)
+ ])
+ return model_fn_ops
+
+ elif output_type == model.ModelBuilderOutputType.ESTIMATOR_SPEC:
+ fusion_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_no_train_op_fn,
logits=tree_train_logits)
- dnn_train_op = head.create_model_fn_ops(
+ dnn_spec = head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_dnn_train_op_fn,
- logits=dnn_logits).train_op
- tree_train_op = head.create_model_fn_ops(
+ logits=dnn_logits)
+ tree_spec = head.create_estimator_spec(
features=tree_features,
mode=mode,
labels=labels,
train_op_fn=_tree_train_op_fn,
- logits=tree_train_logits).train_op
-
- if tree_center_bias:
- num_trees += 1
- finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
-
- model_fn_ops.training_hooks.extend([
- trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
- tree_train_op),
- trainer_hooks.StopAfterNTrees(num_trees, attempted_trees, finalized_trees)
- ])
+ logits=tree_train_logits)
- return model_fn_ops
+ training_hooks = [
+ trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
+ tree_spec.train_op),
+ trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
+ finalized_trees)
+ ]
+ fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
+ list(fusion_spec.training_hooks))
+ return fusion_spec
class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
@@ -697,3 +739,100 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
model_dir=model_dir,
config=config,
feature_engineering_fn=feature_engineering_fn)
+
+
+class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
+ """Initializes a core version of DNNBoostedTreeCombinedEstimator.
+
+ Args:
+ dnn_hidden_units: List of hidden units per layer for DNN.
+ dnn_feature_columns: An iterable containing all the feature columns
+ used by the model's DNN.
+ tree_learner_config: A config for the tree learner.
+ num_trees: Number of trees to grow model to after training DNN.
+ tree_examples_per_layer: Number of examples to accumulate before
+ growing the tree a layer. This value has a big impact on model
+ quality and should be set equal to the number of examples in
+ training dataset if possible. It can also be a function that computes
+ the number of examples based on the depth of the layer that's
+ being built.
+ head: `Head` instance.
+ model_dir: Directory for model exports.
+ config: `RunConfig` of the estimator.
+ dnn_optimizer: string, `Optimizer` object, or callable that defines the
+ optimizer to use for training the DNN. If `None`, will use the Adagrad
+ optimizer with default learning rate.
+ dnn_activation_fn: Activation function applied to each layer of the DNN.
+ If `None`, will use `tf.nn.relu`.
+ dnn_dropout: When not `None`, the probability to drop out a given
+ unit in the DNN.
+ dnn_input_layer_partitioner: Partitioner for input layer of the DNN.
+ Defaults to `min_max_variable_partitioner` with `min_slice_size`
+ 64 << 20.
+ dnn_input_layer_to_tree: Whether to provide the DNN's input layer
+ as a feature to the tree.
+ dnn_steps_to_train: Number of steps to train dnn for before switching
+ to gbdt.
+ predict_with_tree_only: Whether to use only the tree model output as the
+ final prediction.
+ tree_feature_columns: An iterable containing all the feature columns
+ used by the model's boosted trees. If dnn_input_layer_to_tree is
+ set to True, these features are in addition to dnn_feature_columns.
+ tree_center_bias: Whether a separate tree should be created for
+ first fitting the bias.
+ dnn_to_tree_distillation_param: A Tuple of (float, loss_fn), where the
+ float defines the weight of the distillation loss, and the loss_fn, for
+ computing distillation loss, takes dnn_logits, tree_logits and weight
+ tensor. If the entire tuple is None, no distillation will be applied. If
+ only the loss_fn is None, we will take the sigmoid/softmax cross entropy
+ loss be default. When distillation is applied, `predict_with_tree_only`
+ will be set to True.
+ """
+
+ def __init__(self,
+ dnn_hidden_units,
+ dnn_feature_columns,
+ tree_learner_config,
+ num_trees,
+ tree_examples_per_layer,
+ head,
+ model_dir=None,
+ config=None,
+ dnn_optimizer="Adagrad",
+ dnn_activation_fn=nn.relu,
+ dnn_dropout=None,
+ dnn_input_layer_partitioner=None,
+ dnn_input_layer_to_tree=True,
+ dnn_steps_to_train=10000,
+ predict_with_tree_only=False,
+ tree_feature_columns=None,
+ tree_center_bias=False,
+ dnn_to_tree_distillation_param=None):
+
+ def _model_fn(features, labels, mode, config):
+ return _dnn_tree_combined_model_fn(
+ features=features,
+ labels=labels,
+ mode=mode,
+ head=head,
+ dnn_hidden_units=dnn_hidden_units,
+ dnn_feature_columns=dnn_feature_columns,
+ tree_learner_config=tree_learner_config,
+ num_trees=num_trees,
+ tree_examples_per_layer=tree_examples_per_layer,
+ config=config,
+ dnn_optimizer=dnn_optimizer,
+ dnn_activation_fn=dnn_activation_fn,
+ dnn_dropout=dnn_dropout,
+ dnn_input_layer_partitioner=dnn_input_layer_partitioner,
+ dnn_input_layer_to_tree=dnn_input_layer_to_tree,
+ dnn_steps_to_train=dnn_steps_to_train,
+ predict_with_tree_only=predict_with_tree_only,
+ tree_feature_columns=tree_feature_columns,
+ tree_center_bias=tree_center_bias,
+ dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
+ output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
+ use_core_versions=True)
+
+ super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 9b7acfa664..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
-
+from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
features = {
@@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=False,
+ tree_feature_columns=[core_feature_column.numeric_column("x")])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=True,
+ tree_feature_columns=[])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+
if __name__ == "__main__":
googletest.main()