aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py70
1 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 9b7acfa664..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -28,10 +28,11 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import googletest
-
+from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
features = {
@@ -156,5 +157,72 @@ class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
+
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithNoDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=False,
+ tree_feature_columns=[core_feature_column.numeric_column("x")])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+ def testTrainEvaluateInferDoesNotThrowErrorWithDnnInput(self):
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 3
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ est = estimator.CoreDNNBoostedTreeCombinedEstimator(
+ head=head_fn,
+ dnn_hidden_units=[1],
+ dnn_feature_columns=[core_feature_column.numeric_column("x")],
+ tree_learner_config=learner_config,
+ num_trees=1,
+ tree_examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ dnn_steps_to_train=10,
+ dnn_input_layer_to_tree=True,
+ tree_feature_columns=[])
+
+ # Train for a few steps.
+ est.train(input_fn=_train_input_fn, steps=1000)
+ res = est.evaluate(input_fn=_eval_input_fn, steps=1)
+ self.assertLess(0.5, res["auc"])
+ est.predict(input_fn=_eval_input_fn)
+
+
if __name__ == "__main__":
googletest.main()