diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-16 16:54:05 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-16 16:58:50 -0700 |
commit | 363839e5b53ce42758a5127084e319583b5f3c48 (patch) | |
tree | b2eb69f9470246907333f7eaeb5338eedd53d4e0 /tensorflow/contrib/tensor_forest | |
parent | 288991b17d75330476ef7d61875d345dc9aa233f (diff) |
Use binary classification head for num_classes = 2.
PiperOrigin-RevId: 209074011
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r-- | tensorflow/contrib/tensor_forest/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/tensor_forest/client/random_forest.py | 20 |
2 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 164f3e58e6..22d6e499d2 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -515,6 +515,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":client_lib", + "//tensorflow/contrib/estimator:head", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py index 8fa0b3ada9..db970deff5 100644 --- a/tensorflow/contrib/tensor_forest/client/random_forest.py +++ b/tensorflow/contrib/tensor_forest/client/random_forest.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers +from tensorflow.contrib.estimator.python.estimator import head as core_head_lib from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -25,7 +26,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f from tensorflow.contrib.tensor_forest.client import eval_metrics from tensorflow.contrib.tensor_forest.python import tensor_forest from tensorflow.python.estimator import estimator as core_estimator -from tensorflow.python.estimator.canned import head as core_head_lib from tensorflow.python.estimator.export.export_output import PredictOutput from tensorflow.python.feature_column import feature_column as fc_core from tensorflow.python.framework import ops @@ -130,17 +130,23 @@ def _get_default_head(params, weights_name, output_type, name=None): head_name=name) else: if params.regression: - return core_head_lib._regression_head( # pylint:disable=protected-access + return core_head_lib.regression_head( weight_column=weights_name, label_dimension=params.num_outputs, name=name, loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) else: - return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access - n_classes=params.num_classes, - weight_column=weights_name, - name=name, - loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + if params.num_classes == 2: + return core_head_lib.binary_classification_head( + weight_column=weights_name, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) + else: + return core_head_lib.multi_class_head( + n_classes=params.num_classes, + weight_column=weights_name, + name=name, + loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS) def get_model_fn(params, graph_builder_class, |