aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py97
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees.py33
-rw-r--r--tensorflow/python/estimator/canned/boosted_trees_test.py63
3 files changed, 71 insertions, 122 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
index e99a87f3b3..eee5910687 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.estimator.python.estimator import boosted_trees
+from tensorflow.core.kernels.boosted_trees import boosted_trees_pb2
from tensorflow.python.estimator.canned import boosted_trees as canned_boosted_trees
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.feature_column import feature_column
@@ -69,10 +70,18 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
for i in range(NUM_FEATURES)
}
- def _assert_checkpoint(self, model_dir, expected_global_step):
- self.assertEqual(expected_global_step,
- checkpoint_utils.load_variable(model_dir,
- ops.GraphKeys.GLOBAL_STEP))
+ def _assert_checkpoint(self, model_dir, global_step, finalized_trees,
+ attempted_layers):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertEqual(
+ finalized_trees,
+ sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized]))
+ self.assertEqual(attempted_layers,
+ ensemble_proto.growing_metadata.num_layers_attempted)
def testTrainAndEvaluateEstimator(self):
input_fn = _make_train_input_fn(is_classification=False)
@@ -88,9 +97,10 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(input_fn, steps=num_steps)
- self._assert_checkpoint(est.model_dir, 11)
+ self._assert_checkpoint(
+ est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10)
eval_res = est.evaluate(input_fn=input_fn, steps=1)
- self.assertAllClose(eval_res['average_loss'], 0.913176)
+ self.assertAllClose(eval_res['average_loss'], 1.008551)
def testInferEstimator(self):
train_input_fn = _make_train_input_fn(is_classification=False)
@@ -108,31 +118,13 @@ class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(train_input_fn, steps=num_steps)
- self._assert_checkpoint(est.model_dir, 6)
-
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
+ # Validate predictions.
predictions = list(est.predict(input_fn=predict_input_fn))
- self.assertEquals(5, len(predictions))
- self.assertAllClose([0.703549], predictions[0]['predictions'])
- self.assertAllClose([0.266539], predictions[1]['predictions'])
- self.assertAllClose([0.256479], predictions[2]['predictions'])
- self.assertAllClose([1.088732], predictions[3]['predictions'])
- self.assertAllClose([1.901732], predictions[4]['predictions'])
-
-
-class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self._feature_columns = {
- feature_column.bucketized_column(
- feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
- BUCKET_BOUNDARIES)
- for i in range(NUM_FEATURES)
- }
-
- def _assert_checkpoint(self, model_dir, expected_global_step):
- self.assertEqual(expected_global_step,
- checkpoint_utils.load_variable(model_dir,
- ops.GraphKeys.GLOBAL_STEP))
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
def testBinaryClassifierTrainInMemoryAndEvalAndInfer(self):
train_input_fn = _make_train_input_fn(is_classification=True)
@@ -145,36 +137,16 @@ class BoostedTreesClassifierTrainInMemoryTest(test_util.TensorFlowTestCase):
n_trees=1,
max_depth=5)
# It will stop after 5 steps because of the max depth and num trees.
- self._assert_checkpoint(est.model_dir, 6)
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
# Check eval.
eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
-
- # Check predict that all labels are correct.
+ # Validate predictions.
predictions = list(est.predict(input_fn=predict_input_fn))
- self.assertEquals(5, len(predictions))
- self.assertAllClose([0], predictions[0]['class_ids'])
- self.assertAllClose([1], predictions[1]['class_ids'])
- self.assertAllClose([1], predictions[2]['class_ids'])
- self.assertAllClose([0], predictions[3]['class_ids'])
- self.assertAllClose([0], predictions[4]['class_ids'])
-
-
-class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self._feature_columns = {
- feature_column.bucketized_column(
- feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
- BUCKET_BOUNDARIES)
- for i in range(NUM_FEATURES)
- }
-
- def _assert_checkpoint(self, model_dir, expected_global_step):
- self.assertEqual(expected_global_step,
- checkpoint_utils.load_variable(model_dir,
- ops.GraphKeys.GLOBAL_STEP))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
def testRegressorTrainInMemoryAndEvalAndInfer(self):
train_input_fn = _make_train_input_fn(is_classification=False)
@@ -187,20 +159,17 @@ class BoostedTreesRegressorTrainInMemoryTest(test_util.TensorFlowTestCase):
n_trees=1,
max_depth=5)
# It will stop after 5 steps because of the max depth and num trees.
- self._assert_checkpoint(est.model_dir, 6)
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
# Check eval.
eval_res = est.evaluate(input_fn=train_input_fn, steps=1)
- self.assertAllClose(eval_res['average_loss'], 2.2136638)
-
+ self.assertAllClose(eval_res['average_loss'], 2.478283)
# Validate predictions.
predictions = list(est.predict(input_fn=predict_input_fn))
- self.assertEquals(5, len(predictions))
- self.assertAllClose([0.703549], predictions[0]['predictions'])
- self.assertAllClose([0.266539], predictions[1]['predictions'])
- self.assertAllClose([0.256479], predictions[2]['predictions'])
- self.assertAllClose([1.088732], predictions[3]['predictions'])
- self.assertAllClose([1.901732], predictions[4]['predictions'])
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/boosted_trees.py b/tensorflow/python/estimator/canned/boosted_trees.py
index 500ea03ea7..c5d5455b1a 100644
--- a/tensorflow/python/estimator/canned/boosted_trees.py
+++ b/tensorflow/python/estimator/canned/boosted_trees.py
@@ -209,8 +209,8 @@ class _CacheTrainingStatesUsingVariables(object):
name='cache_insert')
-class StopAtAttemptsHook(session_run_hook.SessionRunHook):
- """Hook that requests stop at the number of trees."""
+class _StopAtAttemptsHook(session_run_hook.SessionRunHook):
+ """Hook that requests stop at the number of attempts."""
def __init__(self, num_finalized_trees_tensor, num_attempted_layers_tensor,
max_trees, max_depth):
@@ -224,25 +224,17 @@ class StopAtAttemptsHook(session_run_hook.SessionRunHook):
[self._num_finalized_trees_tensor, self._num_attempted_layers_tensor])
def after_run(self, run_context, run_values):
+ # The num_* tensors should be read in a session.run() call separate from the
+ # training one, in order to get the values after growing.
+ # So, if the limit is being approached, fetch the actual values with an
+ # additional session.run() call.
num_finalized_trees, num_attempted_layers = run_values.results
+ if (num_finalized_trees >= self._max_trees - 1 or
+ num_attempted_layers > 2 * self._max_trees * self._max_depth - 1):
+ num_finalized_trees, num_attempted_layers = run_context.session.run(
+ [self._num_finalized_trees_tensor, self._num_attempted_layers_tensor])
if (num_finalized_trees >= self._max_trees or
- 1.0 * num_attempted_layers / self._max_depth > 2 * self._max_trees):
- run_context.request_stop()
-
-
-class StopAtNumTreesHook(session_run_hook.SessionRunHook):
- """Hook that requests stop at the number of trees."""
-
- def __init__(self, num_trees_tensor, max_trees):
- self._num_trees_tensor = num_trees_tensor
- self._max_trees = max_trees
-
- def before_run(self, run_context):
- return session_run_hook.SessionRunArgs(self._num_trees_tensor)
-
- def after_run(self, run_context, run_values):
- num_trees = run_values.results
- if num_trees > self._max_trees:
+ num_attempted_layers > 2 * self._max_trees * self._max_depth):
run_context.request_stop()
@@ -468,7 +460,8 @@ def _bt_model_fn(
# Add an early stop hook.
estimator_spec = estimator_spec._replace(
training_hooks=estimator_spec.training_hooks +
- (StopAtNumTreesHook(num_trees, tree_hparams.n_trees),))
+ (_StopAtAttemptsHook(num_finalized_trees, num_attempted_layers,
+ tree_hparams.n_trees, tree_hparams.max_depth),))
return estimator_spec
diff --git a/tensorflow/python/estimator/canned/boosted_trees_test.py b/tensorflow/python/estimator/canned/boosted_trees_test.py
index 01e5cc7a5d..625745a3f9 100644
--- a/tensorflow/python/estimator/canned/boosted_trees_test.py
+++ b/tensorflow/python/estimator/canned/boosted_trees_test.py
@@ -69,7 +69,7 @@ def _make_train_input_fn(is_classification):
return _input_fn
-class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
+class BoostedTreesEstimatorTest(test_util.TensorFlowTestCase):
def setUp(self):
self._feature_columns = {
@@ -79,10 +79,18 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
for i in range(NUM_FEATURES)
}
- def _assert_checkpoint(self, model_dir, expected_global_step):
- self.assertEqual(expected_global_step,
- checkpoint_utils.load_variable(model_dir,
- ops.GraphKeys.GLOBAL_STEP))
+ def _assert_checkpoint(self, model_dir, global_step, finalized_trees,
+ attempted_layers):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+ serialized = reader.get_tensor('boosted_trees:0_serialized')
+ ensemble_proto = boosted_trees_pb2.TreeEnsemble()
+ ensemble_proto.ParseFromString(serialized)
+ self.assertEqual(
+ finalized_trees,
+ sum([1 for t in ensemble_proto.tree_metadata if t.is_finalized]))
+ self.assertEqual(attempted_layers,
+ ensemble_proto.growing_metadata.num_layers_attempted)
def testTrainAndEvaluateBinaryClassifier(self):
input_fn = _make_train_input_fn(is_classification=True)
@@ -97,7 +105,8 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(input_fn, steps=num_steps)
- self._assert_checkpoint(est.model_dir, 6)
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
eval_res = est.evaluate(input_fn=input_fn, steps=1)
self.assertAllClose(eval_res['accuracy'], 1.0)
@@ -118,29 +127,9 @@ class BoostedTreesClassifierTest(test_util.TensorFlowTestCase):
est.train(train_input_fn, steps=num_steps)
predictions = list(est.predict(input_fn=predict_input_fn))
- self.assertEquals(5, len(predictions))
# All labels are correct.
- self.assertAllClose([0], predictions[0]['class_ids'])
- self.assertAllClose([1], predictions[1]['class_ids'])
- self.assertAllClose([1], predictions[2]['class_ids'])
- self.assertAllClose([0], predictions[3]['class_ids'])
- self.assertAllClose([0], predictions[4]['class_ids'])
-
-
-class BoostedTreesRegressionTest(test_util.TensorFlowTestCase):
-
- def setUp(self):
- self._feature_columns = {
- feature_column.bucketized_column(
- feature_column.numeric_column('f_%d' % i, dtype=dtypes.float32),
- BUCKET_BOUNDARIES)
- for i in range(NUM_FEATURES)
- }
-
- def _assert_checkpoint(self, model_dir, expected_global_step):
- self.assertEqual(expected_global_step,
- checkpoint_utils.load_variable(model_dir,
- ops.GraphKeys.GLOBAL_STEP))
+ self.assertAllClose([[0], [1], [1], [0], [0]],
+ [pred['class_ids'] for pred in predictions])
def testTrainAndEvaluateRegressor(self):
input_fn = _make_train_input_fn(is_classification=False)
@@ -155,9 +144,10 @@ class BoostedTreesRegressionTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(input_fn, steps=num_steps)
- self._assert_checkpoint(est.model_dir, 11)
+ self._assert_checkpoint(
+ est.model_dir, global_step=10, finalized_trees=2, attempted_layers=10)
eval_res = est.evaluate(input_fn=input_fn, steps=1)
- self.assertAllClose(eval_res['average_loss'], 0.913176)
+ self.assertAllClose(eval_res['average_loss'], 1.008551)
def testInferRegressor(self):
train_input_fn = _make_train_input_fn(is_classification=False)
@@ -174,16 +164,13 @@ class BoostedTreesRegressionTest(test_util.TensorFlowTestCase):
num_steps = 100
# Train for a few steps, and validate final checkpoint.
est.train(train_input_fn, steps=num_steps)
- self._assert_checkpoint(est.model_dir, 6)
+ self._assert_checkpoint(
+ est.model_dir, global_step=5, finalized_trees=1, attempted_layers=5)
predictions = list(est.predict(input_fn=predict_input_fn))
-
- self.assertEquals(5, len(predictions))
- self.assertAllClose([0.703549], predictions[0]['predictions'])
- self.assertAllClose([0.266539], predictions[1]['predictions'])
- self.assertAllClose([0.256479], predictions[2]['predictions'])
- self.assertAllClose([1.088732], predictions[3]['predictions'])
- self.assertAllClose([1.901732], predictions[4]['predictions'])
+ self.assertAllClose(
+ [[0.571619], [0.262821], [0.124549], [0.956801], [1.769801]],
+ [pred['predictions'] for pred in predictions])
class ModelFnTests(test_util.TensorFlowTestCase):