aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch/model.py')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py74
1 files changed, 53 insertions, 21 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 0e8a56e6e9..2fbe72951a 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -29,7 +29,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import training_util
-def model_builder(features, labels, mode, params, config):
+class ModelBuilderOutputType(object):
+ MODEL_FN_OPS = 0
+ ESTIMATOR_SPEC = 1
+
+
+def model_builder(features,
+ labels,
+ mode,
+ params,
+ config,
+ output_type=ModelBuilderOutputType.MODEL_FN_OPS):
"""Multi-machine batch gradient descent tree model.
Args:
@@ -115,31 +125,53 @@ def model_builder(features, labels, mode, params, config):
return update_op
create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
- if use_core_libs and callable(create_estimator_spec_op):
- model_fn_ops = head.create_estimator_spec(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
- else:
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
- if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
- model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
- gbdt_batch.LEAF_INDEX]
+
if num_trees:
if center_bias:
num_trees += 1
finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor()
- model_fn_ops.training_hooks.append(
+ training_hooks = [
trainer_hooks.StopAfterNTrees(num_trees, attempted_trees,
- finalized_trees))
+ finalized_trees)
+ ]
+
+ if output_type == ModelBuilderOutputType.MODEL_FN_OPS:
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(
+ model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ if output_leaf_index and gbdt_batch.LEAF_INDEX in predictions_dict:
+ model_fn_ops.predictions[gbdt_batch.LEAF_INDEX] = predictions_dict[
+ gbdt_batch.LEAF_INDEX]
+
+ model_fn_ops.training_hooks.extend(training_hooks)
+ return model_fn_ops
+ elif output_type == ModelBuilderOutputType.ESTIMATOR_SPEC:
+ assert callable(create_estimator_spec_op)
+ estimator_spec = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+
+ estimator_spec = estimator_spec._replace(
+ training_hooks=training_hooks + list(estimator_spec.training_hooks))
+ return estimator_spec
+
return model_fn_ops