aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-17 16:19:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-17 16:30:47 -0700
commit34a4b21f8f9dea64d3e99a97f639396f2d5556d3 (patch)
tree2020566b3dd6d90232f760cb462459a8147b6f69
parent47e4d4b6b5742350233a8fd83cd81269792ed286 (diff)
Change GBDTClassifer to internally use twice differntiable implementation of
multiclass cross entropy loss. PiperOrigin-RevId: 172532288
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py15
-rw-r--r--tensorflow/contrib/boosted_trees/examples/mnist.py34
-rw-r--r--tensorflow/contrib/boosted_trees/python/utils/losses.py5
3 files changed, 21 insertions, 33 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index f8028acbdb..01752416b3 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -19,8 +19,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.boosted_trees.estimator_batch import model
+from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.ops import math_ops
class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
@@ -65,10 +67,21 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
Raises:
ValueError: If learner_config is not valid.
"""
+ if n_classes > 2:
+ # For multi-class classification, use our loss implementation that
+ # supports second order derivative.
+ def loss_fn(labels, logits, weights=None):
+ result = losses.per_example_maxent_loss(
+ labels=labels, logits=logits, weights=weights,
+ num_classes=n_classes)
+ return math_ops.reduce_mean(result[0])
+ else:
+ loss_fn = None
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
- enable_centered_bias=False)
+ enable_centered_bias=False,
+ loss_fn=loss_fn)
if learner_config.num_classes == 0:
learner_config.num_classes = n_classes
elif learner_config.num_classes != n_classes:
diff --git a/tensorflow/contrib/boosted_trees/examples/mnist.py b/tensorflow/contrib/boosted_trees/examples/mnist.py
index a3b1cb5154..0539d77720 100644
--- a/tensorflow/contrib/boosted_trees/examples/mnist.py
+++ b/tensorflow/contrib/boosted_trees/examples/mnist.py
@@ -35,18 +35,13 @@ from __future__ import division
from __future__ import print_function
import argparse
-import functools
import sys
import numpy as np
import tensorflow as tf
-from tensorflow.contrib import metrics as metrics_lib
-from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head
-from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeEstimator
+from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
from tensorflow.contrib.boosted_trees.proto import learner_pb2
-from tensorflow.contrib.boosted_trees.python.utils import losses
from tensorflow.contrib.learn import learn_runner
-from tensorflow.python.ops import math_ops
def get_input_fn(dataset_split,
@@ -88,36 +83,13 @@ def _get_tfbt(output_dir):
learner_config.growing_mode = growing_mode
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
- # Use Cross Entropy loss (the impl in losses is twice differentiable).
- loss_fn = functools.partial(
- losses.per_example_maxent_loss, num_classes=num_classes)
- logit_dim = num_classes
learner_config.multi_class_strategy = (
learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
- # Since we use custom head, we need to tell how accuracy is calculated.
- def _multiclass_metrics(predictions, labels, weights):
- """Prepares eval metrics for multiclass eval."""
- metrics = dict()
- logits = predictions["scores"]
- classes = math_ops.argmax(logits, 1)
- metrics["accuracy"] = metrics_lib.streaming_accuracy(
- classes, labels, weights)
- return metrics
-
- metrics_fn = _multiclass_metrics
- # Use custom loss head so we can provide our loss (cross entropy for
- # multiclass).
- head = custom_loss_head.CustomLossHead(
- loss_fn=loss_fn,
- link_fn=tf.identity,
- logit_dimension=logit_dim,
- metrics_fn=metrics_fn)
-
# Create a TF Boosted trees estimator that can take in custom loss.
- estimator = GradientBoostedDecisionTreeEstimator(
+ estimator = GradientBoostedDecisionTreeClassifier(
learner_config=learner_config,
- head=head,
+ n_classes=num_classes,
examples_per_layer=FLAGS.examples_per_layer,
model_dir=output_dir,
num_trees=FLAGS.num_trees,
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index 4f128b2301..1e8b3ac08a 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -101,7 +101,10 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15):
unweighted_loss = array_ops.expand_dims(-math_ops.log(probs_for_real_class),
1)
- return unweighted_loss * weights, control_flow_ops.no_op()
+ if weights is None:
+ return unweighted_loss, control_flow_ops.no_op()
+ else:
+ return unweighted_loss * weights, control_flow_ops.no_op()
def per_example_squared_loss(labels, weights, predictions):