From 34a4b21f8f9dea64d3e99a97f639396f2d5556d3 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 17 Oct 2017 16:19:30 -0700
Subject: Change GBDTClassifier to internally use a twice differentiable
 implementation of multiclass cross entropy loss.

PiperOrigin-RevId: 172532288
---
 .../boosted_trees/estimator_batch/estimator.py     | 15 +++++++++-
 tensorflow/contrib/boosted_trees/examples/mnist.py | 34 ++--------------------
 .../contrib/boosted_trees/python/utils/losses.py   |  5 +++-
 3 files changed, 21 insertions(+), 33 deletions(-)

diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index f8028acbdb..01752416b3 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -19,8 +19,10 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.boosted_trees.estimator_batch import model
+from tensorflow.contrib.boosted_trees.python.utils import losses
 from tensorflow.contrib.learn.python.learn.estimators import estimator
 from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
+from tensorflow.python.ops import math_ops
 
 
 class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
@@ -65,10 +67,21 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
     Raises:
       ValueError: If learner_config is not valid.
     """
+    if n_classes > 2:
+      # For multi-class classification, use our loss implementation that
+      # supports second order derivatives.
+      def loss_fn(labels, logits, weights=None):
+        result = losses.per_example_maxent_loss(
+            labels=labels, logits=logits, weights=weights,
+            num_classes=n_classes)
+        return math_ops.reduce_mean(result[0])
+    else:
+      loss_fn = None
     head = head_lib.multi_class_head(
         n_classes=n_classes,
         weight_column_name=weight_column_name,
-        enable_centered_bias=False)
+        enable_centered_bias=False,
+        loss_fn=loss_fn)
     if learner_config.num_classes == 0:
       learner_config.num_classes = n_classes
     elif learner_config.num_classes != n_classes:
diff --git a/tensorflow/contrib/boosted_trees/examples/mnist.py b/tensorflow/contrib/boosted_trees/examples/mnist.py
index a3b1cb5154..0539d77720 100644
--- a/tensorflow/contrib/boosted_trees/examples/mnist.py
+++ b/tensorflow/contrib/boosted_trees/examples/mnist.py
@@ -35,18 +35,13 @@ from __future__ import division
 from __future__ import print_function
 
 import argparse
-import functools
 import sys
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.contrib import metrics as metrics_lib
-from tensorflow.contrib.boosted_trees.estimator_batch import custom_loss_head
-from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeEstimator
+from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
 from tensorflow.contrib.boosted_trees.proto import learner_pb2
-from tensorflow.contrib.boosted_trees.python.utils import losses
 from tensorflow.contrib.learn import learn_runner
-from tensorflow.python.ops import math_ops
 
 
 def get_input_fn(dataset_split,
@@ -88,36 +83,13 @@ def _get_tfbt(output_dir):
   learner_config.growing_mode = growing_mode
   run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
 
-  # Use Cross Entropy loss (the impl in losses is twice differentiable).
-  loss_fn = functools.partial(
-      losses.per_example_maxent_loss, num_classes=num_classes)
-  logit_dim = num_classes
   learner_config.multi_class_strategy = (
       learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)
 
-  # Since we use custom head, we need to tell how accuracy is calculated.
-  def _multiclass_metrics(predictions, labels, weights):
-    """Prepares eval metrics for multiclass eval."""
-    metrics = dict()
-    logits = predictions["scores"]
-    classes = math_ops.argmax(logits, 1)
-    metrics["accuracy"] = metrics_lib.streaming_accuracy(
-        classes, labels, weights)
-    return metrics
-
-  metrics_fn = _multiclass_metrics
-  # Use custom loss head so we can provide our loss (cross entropy for
-  # multiclass).
-  head = custom_loss_head.CustomLossHead(
-      loss_fn=loss_fn,
-      link_fn=tf.identity,
-      logit_dimension=logit_dim,
-      metrics_fn=metrics_fn)
 
-  # Create a TF Boosted trees estimator that can take in custom loss.
-  estimator = GradientBoostedDecisionTreeEstimator(
+  estimator = GradientBoostedDecisionTreeClassifier(
       learner_config=learner_config,
-      head=head,
+      n_classes=num_classes,
       examples_per_layer=FLAGS.examples_per_layer,
       model_dir=output_dir,
       num_trees=FLAGS.num_trees,
diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py
index 4f128b2301..1e8b3ac08a 100644
--- a/tensorflow/contrib/boosted_trees/python/utils/losses.py
+++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py
@@ -101,7 +101,10 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15):
 
   unweighted_loss = array_ops.expand_dims(
       -math_ops.log(probs_for_real_class), 1)
-  return unweighted_loss * weights, control_flow_ops.no_op()
+  if weights is None:
+    return unweighted_loss, control_flow_ops.no_op()
+  else:
+    return unweighted_loss * weights, control_flow_ops.no_op()
 
 
 def per_example_squared_loss(labels, weights, predictions):
--
cgit v1.2.3
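
Note: a minimal usage sketch of the API after this change. With n_classes > 2,
GradientBoostedDecisionTreeClassifier now wires per_example_maxent_loss in
internally, so the CustomLossHead plumbing removed from examples/mnist.py above
is no longer needed. The hyperparameter values, model_dir path, and
train_input_fn below are illustrative assumptions, not part of this commit.

  from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
  from tensorflow.contrib.boosted_trees.proto import learner_pb2

  # Configure the learner as in the updated examples/mnist.py.
  learner_config = learner_pb2.LearnerConfig()
  learner_config.multi_class_strategy = (
      learner_pb2.LearnerConfig.DIAGONAL_HESSIAN)

  # n_classes > 2 selects the twice differentiable cross entropy loss added
  # by this change; n_classes == 2 keeps the default binary head behavior.
  classifier = GradientBoostedDecisionTreeClassifier(
      learner_config=learner_config,
      n_classes=10,             # e.g. ten digit classes (assumed)
      examples_per_layer=1000,  # assumed value
      model_dir="/tmp/gbdt",    # assumed path
      num_trees=10)             # assumed value

  # classifier.fit(input_fn=train_input_fn)  # train_input_fn is assumed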