aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py29
1 files changed, 28 insertions, 1 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 2c2dcb039d..f787d3cdb8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -182,7 +182,7 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
config = run_config.RunConfig()
head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
- loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
model = estimator.GradientBoostedDecisionTreeRanker(
head=head_fn,
@@ -203,5 +203,32 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.predict(input_fn=_infer_ranking_train_input_fn)
+class CoreGradientBoostedDecisionTreeEstimator(test_util.TensorFlowTestCase):
+
+ def testTrainEvaluateInferDoesNotThrowError(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreGradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ est.evaluate(input_fn=_eval_input_fn, steps=1)
+ est.predict(input_fn=_eval_input_fn)
+
+
if __name__ == "__main__":
googletest.main()