aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-01 17:34:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 17:39:02 -0700
commitc17146df687432cfb58a964364931cd1e5631471 (patch)
tree94d40d862eef055c9c537e299d3998a0c6975835 /tensorflow/contrib/boosted_trees
parent6e9c1b57087e15ea850b20bace7881fe95a86854 (diff)
Allow to set global step to a particular value, after the early stopping triggered by the number of trees fired.
PiperOrigin-RevId: 207024504
Diffstat (limited to 'tensorflow/contrib/boosted_trees')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py48
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py113
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py46
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py16
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py17
5 files changed, 189 insertions, 51 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index dbfa69edcb..194a5c8754 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -86,7 +86,8 @@ def _dnn_tree_combined_model_fn(
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
use_core_versions=False,
- output_type=model.ModelBuilderOutputType.MODEL_FN_OPS):
+ output_type=model.ModelBuilderOutputType.MODEL_FN_OPS,
+ override_global_step_value=None):
"""DNN and GBDT combined model_fn.
Args:
@@ -135,6 +136,12 @@ def _dnn_tree_combined_model_fn(
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
+ (new interface).
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
Returns:
A `ModelFnOps` object.
@@ -350,7 +357,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_train_op, dnn_steps_to_train,
tree_train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
])
return model_fn_ops
@@ -378,7 +386,8 @@ def _dnn_tree_combined_model_fn(
trainer_hooks.SwitchTrainOp(dnn_spec.train_op, dnn_steps_to_train,
tree_spec.train_op),
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees)
+ finalized_trees,
+ override_global_step_value)
]
fusion_spec = fusion_spec._replace(training_hooks=training_hooks +
list(fusion_spec.training_hooks))
@@ -411,7 +420,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedClassifier instance.
Args:
@@ -467,6 +477,10 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
head = head_lib.multi_class_head(
n_classes=n_classes,
@@ -497,7 +511,8 @@ class DNNBoostedTreeCombinedClassifier(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedClassifier, self).__init__(
model_fn=_model_fn,
@@ -531,7 +546,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedRegressor instance.
Args:
@@ -587,6 +603,10 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -622,7 +642,8 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
model_fn=_model_fn,
@@ -657,7 +678,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=None,
tree_center_bias=False,
dnn_to_tree_distillation_param=None,
- use_core_versions=False):
+ use_core_versions=False,
+ override_global_step_value=None):
"""Initializes a DNNBoostedTreeCombinedEstimator instance.
Args:
@@ -708,6 +730,10 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
will be set to True.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
"""
def _model_fn(features, labels, mode, config):
@@ -732,7 +758,8 @@ class DNNBoostedTreeCombinedEstimator(estimator.Estimator):
tree_feature_columns=tree_feature_columns,
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
- use_core_versions=use_core_versions)
+ use_core_versions=use_core_versions,
+ override_global_step_value=override_global_step_value)
super(DNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn,
@@ -832,7 +859,8 @@ class CoreDNNBoostedTreeCombinedEstimator(core_estimator.Estimator):
tree_center_bias=tree_center_bias,
dnn_to_tree_distillation_param=dnn_to_tree_distillation_param,
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC,
- use_core_versions=True)
+ use_core_versions=True,
+ override_global_step_value=None)
super(CoreDNNBoostedTreeCombinedEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 2df879f924..2fa3db1e8d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -49,7 +49,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -83,6 +84,14 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This should be used to reset global
+ step to a number > number of steps used to train the current ensemble.
+ For example, the usual way is to train a number of trees and set a very
+ large number of training steps. When the training is done (number of
+ trees were trained), this parameter can be used to set the global step
+ to a large value, making it look like that number of training steps ran.
+ If None, no override of global step will happen.
Raises:
ValueError: If learner_config is not valid.
@@ -123,6 +132,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -146,7 +156,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -180,6 +191,14 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This should be used to reset global
+ step to a number > number of steps used to train the current ensemble.
+ For example, the usual way is to train a number of trees and set a very
+ large number of training steps. When the training is done (number of
+ trees were trained), this parameter can be used to set the global step
+ to a large value, making it look like that number of training steps ran.
+ If None, no override of global step will happen.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -203,6 +222,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -228,7 +248,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
logits_modifier_function=None,
center_bias=True,
use_core_libs=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -258,6 +279,14 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This should be used to reset global
+ step to a number > number of steps used to train the current ensemble.
+ For example, the usual way is to train a number of trees and set a very
+ large number of training steps. When the training is done (number of
+ trees were trained), this parameter can be used to set the global step
+ to a large value, making it look like that number of training steps ran.
+ If None, no override of global step will happen.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -272,6 +301,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -281,24 +311,23 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
class GradientBoostedDecisionTreeRanker(estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
- def __init__(
- self,
- learner_config,
- examples_per_layer,
- head,
- ranking_model_pair_keys,
- num_trees=None,
- feature_columns=None,
- weight_column_name=None,
- model_dir=None,
- config=None,
- label_keys=None,
- feature_engineering_fn=None,
- logits_modifier_function=None,
- center_bias=False,
- use_core_libs=False,
- output_leaf_index=False,
- ):
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ feature_engineering_fn=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ use_core_libs=False,
+ output_leaf_index=False,
+ override_global_step_value=None):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -338,7 +367,14 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
-
+ override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This should be used to reset global
+ step to a number > number of steps used to train the current ensemble.
+ For example, the usual way is to train a number of trees and set a very
+ large number of training steps. When the training is done (number of
+ trees were trained), this parameter can be used to set the global step
+ to a large value, making it look like that number of training steps ran.
+ If None, no override of global step will happen.
Raises:
ValueError: If learner_config is not valid.
"""
@@ -357,6 +393,7 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
+ 'override_global_step_value': override_global_step_value
},
model_dir=model_dir,
config=config,
@@ -435,6 +472,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
+ 'override_global_step_value': None
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -445,22 +483,20 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
"""A ranking estimator using gradient boosted decision trees."""
- def __init__(
- self,
- learner_config,
- examples_per_layer,
- head,
- ranking_model_pair_keys,
- num_trees=None,
- feature_columns=None,
- weight_column_name=None,
- model_dir=None,
- config=None,
- label_keys=None,
- logits_modifier_function=None,
- center_bias=False,
- output_leaf_index=False,
- ):
+ def __init__(self,
+ learner_config,
+ examples_per_layer,
+ head,
+ ranking_model_pair_keys,
+ num_trees=None,
+ feature_columns=None,
+ weight_column_name=None,
+ model_dir=None,
+ config=None,
+ label_keys=None,
+ logits_modifier_function=None,
+ center_bias=False,
+ output_leaf_index=False):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -519,6 +555,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
+ 'override_global_step_value': None
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index 9e9febbbef..83ef87c6fd 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -25,10 +25,12 @@ from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.feature_column import feature_column_lib as core_feature_column
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import googletest
+from tensorflow.python.training import checkpoint_utils
def _train_input_fn():
@@ -68,6 +70,10 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
self._export_dir_base = tempfile.mkdtemp() + "export/"
gfile.MkDir(self._export_dir_base)
+ def _assert_checkpoint(self, model_dir, global_step):
+ reader = checkpoint_utils.load_checkpoint(model_dir)
+ self.assertEqual(global_step, reader.get_tensor(ops.GraphKeys.GLOBAL_STEP))
+
def testFitAndEvaluateDontThrowException(self):
learner_config = learner_pb2.LearnerConfig()
learner_config.num_classes = 2
@@ -202,6 +208,46 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
model.evaluate(input_fn=_ranking_train_input_fn, steps=1)
model.predict(input_fn=_infer_ranking_train_input_fn)
+ def testDoesNotOverrideGlobalSteps(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 2
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=False)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ # When no override of global steps, 5 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
+
+ def testOverridesGlobalSteps(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 2
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")],
+ output_leaf_index=False,
+ override_global_step_value=10000000)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ self._assert_checkpoint(classifier.model_dir, global_step=10000000)
+
class CoreGradientBoostedDecisionTreeEstimators(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 161cc42cb0..04b46c3483 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -58,6 +58,10 @@ def model_builder(features,
* weight_column_name: The name of weight column.
* center_bias: Whether a separate tree should be created for first fitting
the bias.
+ * override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
@@ -76,6 +80,7 @@ def model_builder(features,
use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -136,7 +141,8 @@ def model_builder(features,
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees,
+ override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
@@ -206,6 +212,10 @@ def ranking_model_builder(features,
for left and right part of the training pairs for ranking. For example,
for an Example with features "a.f1" and "b.f1", the keys would be
("a", "b").
+ * override_global_step_value: If after the training is done, global step
+ value must be reset to this value. This is particularly useful for hyper
+ parameter tuning, which can't recognize early stopping due to the number
+ of trees. If None, no override of global step will happen.
config: `RunConfig` of the estimator.
output_type: Whether to return ModelFnOps (old interface) or EstimatorSpec
(new interface).
@@ -226,6 +236,7 @@ def ranking_model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
+ override_global_step_value = params.get("override_global_step_value", None)
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -347,7 +358,8 @@ def ranking_model_builder(features,
gbdt_model_main.get_number_of_trees_tensor())
training_hooks.append(
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees,
+ override_global_step_value))
if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
if use_core_libs and callable(create_estimator_spec_op):
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
index 2e4151cac4..cb9f020b88 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py
@@ -25,6 +25,7 @@ from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArg
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache
@@ -150,12 +151,23 @@ class FeedFnHook(session_run_hook.SessionRunHook):
class StopAfterNTrees(session_run_hook.SessionRunHook):
"""Stop training after building N full trees."""
- def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor):
+ def __init__(self, n, num_attempted_trees_tensor, num_finalized_trees_tensor,
+ override_global_step_value):
self._num_trees = n
# num_attempted_trees_tensor and num_finalized_trees_tensor are both
# tensors.
self._num_attempted_trees_tensor = num_attempted_trees_tensor
self._num_finalized_trees_tensor = num_finalized_trees_tensor
+ self._override_global_step_value = override_global_step_value
+
+ def begin(self):
+ self._global_step_tensor = training_util.get_global_step()
+ if self._global_step_tensor is None:
+ raise RuntimeError("Global step should be created.")
+
+ if self._override_global_step_value is not None:
+ self._override_global_step_op = state_ops.assign(
+ self._global_step_tensor, self._override_global_step_value)
def before_run(self, run_context):
del run_context # unused by StopTrainingAfterNTrees.
@@ -175,6 +187,9 @@ class StopAfterNTrees(session_run_hook.SessionRunHook):
num_attempted_trees > 2 * self._num_trees):
logging.info("Requesting stop since we have reached %d trees.",
num_finalized_trees)
+ if self._override_global_step_value is not None:
+ logging.info("Overriding global steps value.")
+ run_context.session.run(self._override_global_step_op)
run_context.request_stop()