aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/boosted_trees/estimator_batch
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-11 15:44:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-11 15:47:16 -0700
commitacad7022b09b090da0684f209ac8d0feb1c986a2 (patch)
tree251ea948e61c58f2177c68dd61b61bec65961940 /tensorflow/contrib/boosted_trees/estimator_batch
parent9ce7791be6980932c249832dc23d464c1b736cc4 (diff)
Adding support of core feature columns and losses to gradient boosted trees estimators.
PiperOrigin-RevId: 192521398
Diffstat (limited to 'tensorflow/contrib/boosted_trees/estimator_batch')
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/BUILD33
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py5
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py96
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py19
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py138
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py71
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py27
7 files changed, 293 insertions, 96 deletions
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index 17e20c4b31..0f65881aee 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -51,6 +51,18 @@ py_library(
],
)
+py_library(
+ name = "estimator_utils",
+ srcs = ["estimator_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/learn",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
py_test(
name = "trainer_hooks_test",
size = "small",
@@ -118,6 +130,7 @@ py_library(
srcs = ["estimator.py"],
srcs_version = "PY2AND3",
deps = [
+ ":estimator_utils",
":model",
"//tensorflow/contrib/boosted_trees:losses",
"//tensorflow/contrib/learn",
@@ -130,6 +143,7 @@ py_library(
srcs = ["dnn_tree_combined_estimator.py"],
srcs_version = "PY2AND3",
deps = [
+ ":estimator_utils",
":trainer_hooks",
"//tensorflow/contrib/boosted_trees:gbdt_batch",
"//tensorflow/contrib/boosted_trees:model_ops_py",
@@ -159,3 +173,22 @@ py_test(
"//tensorflow/python:framework_for_generated_wrappers",
],
)
+
+py_test(
+ name = "estimator_test",
+ size = "medium",
+ srcs = ["estimator_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_gpu",
+ "no_pip_gpu",
+ "notsan",
+ ],
+ deps = [
+ ":estimator",
+ "//tensorflow/contrib/boosted_trees:gbdt_batch",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
index d9b0d89a03..62f1f4122b 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py
@@ -39,7 +39,8 @@ _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d"
def make_custom_export_strategy(name,
convert_fn,
feature_columns,
- export_input_fn):
+ export_input_fn,
+ use_core_columns=False):
"""Makes custom exporter of GTFlow tree format.
Args:
@@ -58,7 +59,7 @@ def make_custom_export_strategy(name,
input_fn = export_input_fn()
(sorted_feature_names, dense_floats, sparse_float_indices, _, _,
sparse_int_indices, _, _) = gbdt_batch.extract_features(
- input_fn.features, feature_columns)
+ input_fn.features, feature_columns, use_core_columns)
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
"""A wrapper to export to SavedModel, and convert it to other formats."""
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
index 2e7b8cba05..449c130b2d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py
@@ -19,25 +19,19 @@ logits of the DNN. The input layer of the DNN (including the embeddings learned
over sparse features) can optionally be provided to the boosted trees as
an additional input feature.
"""
-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from tensorflow.contrib import layers
+from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
from tensorflow.contrib.layers.python.layers import optimizers
-from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
-from tensorflow.contrib.learn.python.learn.estimators import model_fn
-from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib
-from tensorflow.contrib.learn.python.learn.estimators import prediction_key
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.estimator.export import export_output
from tensorflow.python.feature_column import feature_column as feature_column_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
@@ -48,56 +42,8 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
-
_DNN_LEARNING_RATE = 0.001
-_CORE_MODE_TO_CONTRIB_MODE_ = {
- model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN,
- model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL,
- model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER
-}
-
-
-def _core_mode_to_contrib_mode(mode):
- return _CORE_MODE_TO_CONTRIB_MODE_[mode]
-
-
-def _export_outputs_to_output_alternatives(export_outputs):
- """Converts EstimatorSpec.export_outputs to output_alternatives.
-
- Args:
- export_outputs: export_outputs created by create_estimator_spec.
- Returns:
- converted output_alternatives.
- """
- output = dict()
- if export_outputs is not None:
- for key, value in export_outputs.items():
- if isinstance(value, export_output.ClassificationOutput):
- exported_predictions = {
- prediction_key.PredictionKey.SCORES: value.scores,
- prediction_key.PredictionKey.CLASSES: value.classes
- }
- output[key] = (constants.ProblemType.CLASSIFICATION,
- exported_predictions)
- return output
- return None
-
-
-def _estimator_spec_to_model_fn_ops(estimator_spec, is_regression):
- alternatives = []
- if not is_regression:
- _export_outputs_to_output_alternatives(estimator_spec.export_outputs)
-
- return model_fn.ModelFnOps(
- mode=_core_mode_to_contrib_mode(estimator_spec.mode),
- predictions=estimator_spec.predictions,
- loss=estimator_spec.loss,
- train_op=estimator_spec.train_op,
- eval_metric_ops=estimator_spec.eval_metric_ops,
- output_alternatives=alternatives)
-
-
def _get_optimizer(optimizer):
if callable(optimizer):
return optimizer()
@@ -128,8 +74,7 @@ def _dnn_tree_combined_model_fn(features,
dnn_steps_to_train=10000,
tree_feature_columns=None,
tree_center_bias=False,
- use_core_versions=False,
- is_regression=False):
+ use_core_versions=False):
"""DNN and GBDT combined model_fn.
Args:
@@ -169,7 +114,6 @@ def _dnn_tree_combined_model_fn(features,
first fitting the bias.
use_core_versions: Whether feature columns and loss are from the core (as
opposed to contrib) version of tensorflow.
- is_regression: Whether the problem is regression or not.
Returns:
A `ModelFnOps` object.
@@ -305,8 +249,8 @@ def _dnn_tree_combined_model_fn(features,
labels=labels,
train_op_fn=_dnn_train_op_fn,
logits=dnn_logits)
- dnn_train_op = _estimator_spec_to_model_fn_ops(dnn_train_op,
- is_regression).train_op
+ dnn_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ dnn_train_op).train_op
tree_train_op = head.create_estimator_spec(
features=tree_features,
@@ -314,10 +258,10 @@ def _dnn_tree_combined_model_fn(features,
labels=labels,
train_op_fn=_tree_train_op_fn,
logits=tree_train_logits)
- tree_train_op = _estimator_spec_to_model_fn_ops(tree_train_op,
- is_regression).train_op
+ tree_train_op = estimator_utils.estimator_spec_to_model_fn_ops(
+ tree_train_op).train_op
- model_fn_ops = _estimator_spec_to_model_fn_ops(model_fn_ops, is_regression)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
else:
model_fn_ops = head.create_model_fn_ops(
features=features,
@@ -529,26 +473,12 @@ class DNNBoostedTreeCombinedRegressor(estimator.Estimator):
def _model_fn(features, labels, mode, config):
return _dnn_tree_combined_model_fn(
- features,
- labels,
- mode,
- head,
- dnn_hidden_units,
- dnn_feature_columns,
- tree_learner_config,
- num_trees,
- tree_examples_per_layer,
- config,
- dnn_optimizer,
- dnn_activation_fn,
- dnn_dropout,
- dnn_input_layer_partitioner,
- dnn_input_layer_to_tree,
- dnn_steps_to_train,
- tree_feature_columns,
- tree_center_bias,
- use_core_versions,
- is_regression=True)
+ features, labels, mode, head, dnn_hidden_units, dnn_feature_columns,
+ tree_learner_config, num_trees, tree_examples_per_layer, config,
+ dnn_optimizer, dnn_activation_fn, dnn_dropout,
+ dnn_input_layer_partitioner, dnn_input_layer_to_tree,
+ dnn_steps_to_train, tree_feature_columns, tree_center_bias,
+ use_core_versions)
super(DNNBoostedTreeCombinedRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir,
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 70454aa6db..89d0d611d2 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -40,7 +40,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
label_keys=None,
feature_engineering_fn=None,
logits_modifier_function=None,
- center_bias=True):
+ center_bias=True,
+ use_core_libs=False):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -63,7 +64,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
-
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
Raises:
ValueError: If learner_config is not valid.
"""
@@ -99,6 +101,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'examples_per_layer': examples_per_layer,
'center_bias': center_bias,
'logits_modifier_function': logits_modifier_function,
+ 'use_core_libs': use_core_libs,
},
model_dir=model_dir,
config=config,
@@ -120,7 +123,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
config=None,
feature_engineering_fn=None,
logits_modifier_function=None,
- center_bias=True):
+ center_bias=True,
+ use_core_libs=False):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -145,6 +149,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -166,6 +172,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'examples_per_layer': examples_per_layer,
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
+ 'use_core_libs': use_core_libs,
},
model_dir=model_dir,
config=config,
@@ -189,7 +196,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
config=None,
feature_engineering_fn=None,
logits_modifier_function=None,
- center_bias=True):
+ center_bias=True,
+ use_core_libs=False):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -210,6 +218,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
logits_modifier_function: A modifier function for the logits.
center_bias: Whether a separate tree should be created for first fitting
the bias.
+ use_core_libs: Whether feature columns and loss are from the core (as
+ opposed to contrib) version of tensorflow.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -222,6 +232,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'examples_per_layer': examples_per_layer,
'logits_modifier_function': logits_modifier_function,
'center_bias': center_bias,
+ 'use_core_libs': use_core_libs,
},
model_dir=model_dir,
config=config,
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
new file mode 100644
index 0000000000..0d58317bd5
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -0,0 +1,138 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for GBDT estimator."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import tempfile
+from tensorflow.contrib.boosted_trees.estimator_batch import estimator
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
+from tensorflow.contrib.layers.python.layers import feature_column as contrib_feature_column
+from tensorflow.contrib.learn.python.learn.estimators import run_config
+from tensorflow.python.estimator.canned import head as head_lib
+from tensorflow.python.feature_column import feature_column_lib as core_feature_column
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+
+
+def _train_input_fn():
+ features = {"x": constant_op.constant([[2.], [1.], [1.]])}
+ label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32)
+ return features, label
+
+
+def _eval_input_fn():
+ features = {"x": constant_op.constant([[1.], [2.], [2.]])}
+ label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32)
+ return features, label
+
+
+class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
+
+ def setUp(self):
+ self._export_dir_base = tempfile.mkdtemp() + "export/"
+ gfile.MkDir(self._export_dir_base)
+
+ def testFitAndEvaluateDontThrowException(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[contrib_feature_column.real_valued_column("x")])
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ classifier.export(self._export_dir_base)
+
+ def testFitAndEvaluateDontThrowExceptionWithCoreForEstimator(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ # Use core head
+ head_fn = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
+
+ model = estimator.GradientBoostedDecisionTreeEstimator(
+ head=head_fn,
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")],
+ use_core_libs=True)
+
+ model.fit(input_fn=_train_input_fn, steps=15)
+ model.evaluate(input_fn=_eval_input_fn, steps=1)
+ model.export(self._export_dir_base)
+
+ def testFitAndEvaluateDontThrowExceptionWithCoreForClassifier(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ classifier = estimator.GradientBoostedDecisionTreeClassifier(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")],
+ use_core_libs=True)
+
+ classifier.fit(input_fn=_train_input_fn, steps=15)
+ classifier.evaluate(input_fn=_eval_input_fn, steps=1)
+ classifier.export(self._export_dir_base)
+
+ def testFitAndEvaluateDontThrowExceptionWithCoreForRegressor(self):
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.constraints.max_tree_depth = 1
+ model_dir = tempfile.mkdtemp()
+ config = run_config.RunConfig()
+
+ regressor = estimator.GradientBoostedDecisionTreeRegressor(
+ learner_config=learner_config,
+ num_trees=1,
+ examples_per_layer=3,
+ model_dir=model_dir,
+ config=config,
+ feature_columns=[core_feature_column.numeric_column("x")],
+ use_core_libs=True)
+
+ regressor.fit(input_fn=_train_input_fn, steps=15)
+ regressor.evaluate(input_fn=_eval_input_fn, steps=1)
+ regressor.export(self._export_dir_base)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py
new file mode 100644
index 0000000000..c9cf4ae25a
--- /dev/null
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_utils.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for converting between core and contrib feature columns."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn_lib
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export_output
+
+_CORE_MODE_TO_CONTRIB_MODE_ = {
+ model_fn_lib.ModeKeys.TRAIN: contrib_model_fn_lib.ModeKeys.TRAIN,
+ model_fn_lib.ModeKeys.EVAL: contrib_model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT: contrib_model_fn_lib.ModeKeys.INFER
+}
+
+
+def _core_mode_to_contrib_mode(mode):
+ return _CORE_MODE_TO_CONTRIB_MODE_[mode]
+
+
+def _export_outputs_to_output_alternatives(export_outputs):
+ """Converts EstimatorSpec.export_outputs to output_alternatives.
+
+ Args:
+ export_outputs: export_outputs created by create_estimator_spec.
+ Returns:
+ converted output_alternatives.
+ """
+ output = dict()
+ if export_outputs is not None:
+ for key, value in export_outputs.items():
+ if isinstance(value, export_output.ClassificationOutput):
+ exported_predictions = {
+ prediction_key.PredictionKey.SCORES: value.scores,
+ prediction_key.PredictionKey.CLASSES: value.classes
+ }
+ output[key] = (constants.ProblemType.CLASSIFICATION,
+ exported_predictions)
+ return output
+ return None
+
+
+def estimator_spec_to_model_fn_ops(estimator_spec):
+ alternatives = _export_outputs_to_output_alternatives(
+ estimator_spec.export_outputs)
+
+ return model_fn.ModelFnOps(
+ mode=_core_mode_to_contrib_mode(estimator_spec.mode),
+ predictions=estimator_spec.predictions,
+ loss=estimator_spec.loss,
+ train_op=estimator_spec.train_op,
+ eval_metric_ops=estimator_spec.eval_metric_ops,
+ output_alternatives=alternatives)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index c6455a7ea3..15ab6d8145 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import copy
+from tensorflow.contrib.boosted_trees.estimator_batch import estimator_utils
from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks
from tensorflow.contrib.boosted_trees.python.ops import model_ops
from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch
@@ -60,6 +61,7 @@ def model_builder(features, labels, mode, params, config):
feature_columns = params["feature_columns"]
weight_column_name = params["weight_column_name"]
num_trees = params["num_trees"]
+ use_core_libs = params["use_core_libs"]
logits_modifier_function = params["logits_modifier_function"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -93,7 +95,8 @@ def model_builder(features, labels, mode, params, config):
learner_config=learner_config,
feature_columns=feature_columns,
logits_dimension=head.logits_dimension,
- features=training_features)
+ features=training_features,
+ use_core_columns=use_core_libs)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -108,12 +111,22 @@ def model_builder(features, labels, mode, params, config):
update_op = state_ops.assign_add(global_step, 1).op
return update_op
- model_fn_ops = head.create_model_fn_ops(
- features=features,
- mode=mode,
- labels=labels,
- train_op_fn=_train_op_fn,
- logits=logits)
+ create_estimator_spec_op = getattr(head, "create_estimator_spec", None)
+ if use_core_libs and callable(create_estimator_spec_op):
+ model_fn_ops = head.create_estimator_spec(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
+ model_fn_ops = estimator_utils.estimator_spec_to_model_fn_ops(model_fn_ops)
+ else:
+ model_fn_ops = head.create_model_fn_ops(
+ features=features,
+ mode=mode,
+ labels=labels,
+ train_op_fn=_train_op_fn,
+ logits=logits)
if num_trees:
if center_bias:
num_trees += 1