aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py47
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 75ef1b0500..2c2dcb039d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -37,12 +37,31 @@ def _train_input_fn():
return features, label
+def _ranking_train_input_fn():
+ features = {
+ "a.f1": constant_op.constant([[3.], [0.3], [1.]]),
+ "a.f2": constant_op.constant([[0.1], [3.], [1.]]),
+ "b.f1": constant_op.constant([[13.], [0.4], [5.]]),
+ "b.f2": constant_op.constant([[1.], [3.], [0.01]]),
+ }
+ label = constant_op.constant([[0], [0], [1]], dtype=dtypes.int32)
+ return features, label
+
+
def _eval_input_fn():
features = {"x": constant_op.constant([[1.], [2.], [2.]])}
label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
return features, label
+def _infer_ranking_train_input_fn():
+ features = {
+ "f1": constant_op.constant([[3.], [2], [1.]]),
+ "f2": constant_op.constant([[0.1], [3.], [1.]])
+ }
+ return features, None
+
+
class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
def setUp(self):
@@ -155,6 +174,34 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
regressor.evaluate(input_fn=_eval_input_fn, steps=1)
regressor.export(self._export_dir_base)
+ def testRankingDontThrowExceptionForForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ model = estimator.GradientBoostedDecisionTreeRanker(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ use_core_libs=True,
+ feature_columns=[
+ core_feature_column.numeric_column("f1"),
+ core_feature_column.numeric_column("f2")
+ ],
+ ranking_model_pair_keys=("a", "b"))
+
+ model.fit(input_fn=_ranking_train_input_fn, steps=1000)
+ model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
+ model.predict(input_fn=_infer_ranking_train_input_fn)
+
if __name__ == "__main__":
googletest.main()