diff options
author | 2018-09-19 11:08:53 -0700 | |
---|---|---|
committer | 2018-09-19 11:12:23 -0700 | |
commit | a586140da6d0460bbf18384556d5cc449b67b322 (patch) | |
tree | 0d645778e79fbed3eb3cb0e3004d949e585e283c /tensorflow/contrib/estimator/python | |
parent | 5330ede39fa2f1f7b3302bc316061baf180fab44 (diff) |
Add a Python interface for Boosted Trees model explainability (currently includes directional feature contributions); fix an ExampleDebugOutputs bug where it errors with empty trees.
PiperOrigin-RevId: 213658470
Diffstat (limited to 'tensorflow/contrib/estimator/python')
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/boosted_trees.py | 28 | ||||
-rw-r--r-- | tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py | 74 |
2 files changed, 98 insertions, 4 deletions
```diff
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index 7ed77bcce6..11f60c8238 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.estimator import estimator
 from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
+from tensorflow.python.estimator.canned import head as head_lib
 
 
 def _validate_input_fn_and_repeat_dataset(train_input_fn):
@@ -33,7 +34,18 @@ def _validate_input_fn_and_repeat_dataset(train_input_fn):
   return _input_fn
 
 
-class _BoostedTreesEstimator(estimator.Estimator):
+# pylint: disable=protected-access
+def _is_classification_head(head):
+  """Infers if the head is a classification head."""
+  # Check using all classification heads defined in canned/head.py. However, it
+  # is not a complete list - it does not check for other classification heads
+  # not defined in the head library.
+  return isinstance(head,
+                    (head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss,
+                     head_lib._MultiClassHeadWithSoftmaxCrossEntropyLoss))
+
+
+class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase):
   """An Estimator for Tensorflow Boosted Trees models."""
 
   def __init__(self,
@@ -96,8 +108,10 @@ class _BoostedTreesEstimator(estimator.Estimator):
         negative gain). For pre and post pruning, you MUST provide
         tree_complexity >0.
 
+    Raises:
+      ValueError: when wrong arguments are given or unsupported functionalities
+        are requested.
     """
-    # pylint:disable=protected-access
     # HParams for the model.
     tree_hparams = canned_boosted_trees._TreeHParams(
         n_trees, max_depth, learning_rate, l1_regularization, l2_regularization,
@@ -115,8 +129,14 @@ class _BoostedTreesEstimator(estimator.Estimator):
           config=config)
 
     super(_BoostedTreesEstimator, self).__init__(
-        model_fn=_model_fn, model_dir=model_dir, config=config)
-    # pylint:enable=protected-access
+        model_fn=_model_fn,
+        model_dir=model_dir,
+        config=config,
+        feature_columns=feature_columns,
+        head=head,
+        center_bias=center_bias,
+        is_classification=_is_classification_head(head))
+    # pylint: enable=protected-access
 
 
 def boosted_trees_classifier_train_in_memory(
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index b1581f3750..e23d9c0fc4 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -360,5 +360,79 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
                         [pred['predictions'] for pred in predictions])
 
 
+class BoostedTreesDebugOutputTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    self._head = canned_boosted_trees._create_regression_head(label_dimension=1)
+    self._feature_columns = {
+        feature_column.bucketized_column(
+            feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
+            BUCKET_BOUNDARIES) for i in range(NUM_FEATURES)
+    }
+
+  def testContribEstimatorThatDFCIsInPredictions(self):
+    # pylint:disable=protected-access
+    head = canned_boosted_trees._create_regression_head(label_dimension=1)
+    train_input_fn = _make_train_input_fn(is_classification=False)
+    predict_input_fn = numpy_io.numpy_input_fn(
+        x=FEATURES_DICT, y=None, batch_size=1, num_epochs=1, shuffle=False)
+
+    est = boosted_trees._BoostedTreesEstimator(
+        feature_columns=self._feature_columns,
+        n_batches_per_layer=1,
+        head=head,
+        n_trees=1,
+        max_depth=5,
+        center_bias=True)
+    # pylint:enable=protected-access
+
+    num_steps = 100
+    # Train for a few steps. Validate debug outputs in prediction dicts.
+    est.train(train_input_fn, steps=num_steps)
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn)
+    biases, dfcs = zip(*[(pred['bias'], pred['dfc'])
+                         for pred in debug_predictions])
+    self.assertAllClose([1.8] * 5, biases)
+    self.assertAllClose(({
+        0: -0.070499420166015625,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: -0.53763031959533691,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: -0.51756942272186279,
+        1: -0.095000028610229492,
+        2: 0.0
+    }, {
+        0: 0.1563495397567749,
+        1: 0.063333392143249512,
+        2: 0.0
+    }, {
+        0: 0.96934974193572998,
+        1: 0.063333392143249512,
+        2: 0.0
+    }), dfcs)
+
+    # Assert sum(dfcs) + bias == predictions.
+    expected_predictions = [[1.6345005], [1.32570302], [1.1874305],
+                            [2.01968288], [2.83268309]]
+    predictions = [
+        [sum(dfc.values()) + bias] for (dfc, bias) in zip(dfcs, biases)
+    ]
+    self.assertAllClose(expected_predictions, predictions)
+
+    # Test when user doesn't include bias or dfc in predict_keys.
+    debug_predictions = est.experimental_predict_with_explanations(
+        predict_input_fn, predict_keys=['predictions'])
+    for prediction_dict in debug_predictions:
+      self.assertTrue('bias' in prediction_dict)
+      self.assertTrue('dfc' in prediction_dict)
+      self.assertTrue('predictions' in prediction_dict)
+      self.assertEqual(len(prediction_dict), 3)
+
+
 if __name__ == '__main__':
   googletest.main()
```