aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensor_forest
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-16 16:54:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-16 16:58:50 -0700
commit363839e5b53ce42758a5127084e319583b5f3c48 (patch)
treeb2eb69f9470246907333f7eaeb5338eedd53d4e0 /tensorflow/contrib/tensor_forest
parent288991b17d75330476ef7d61875d345dc9aa233f (diff)
Use binary classification head for num_classes = 2.
PiperOrigin-RevId: 209074011
Diffstat (limited to 'tensorflow/contrib/tensor_forest')
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD1
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py20
2 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 164f3e58e6..22d6e499d2 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -515,6 +515,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":client_lib",
+ "//tensorflow/contrib/estimator:head",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index 8fa0b3ada9..db970deff5 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib import layers
+from tensorflow.contrib.estimator.python.estimator import head as core_head_lib
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import head as head_lib
@@ -25,7 +26,6 @@ from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_f
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.python.estimator import estimator as core_estimator
-from tensorflow.python.estimator.canned import head as core_head_lib
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import ops
@@ -130,17 +130,23 @@ def _get_default_head(params, weights_name, output_type, name=None):
head_name=name)
else:
if params.regression:
- return core_head_lib._regression_head( # pylint:disable=protected-access
+ return core_head_lib.regression_head(
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
else:
- return core_head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
- n_classes=params.num_classes,
- weight_column=weights_name,
- name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ if params.num_classes == 2:
+ return core_head_lib.binary_classification_head(
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ else:
+ return core_head_lib.multi_class_head(
+ n_classes=params.num_classes,
+ weight_column=weights_name,
+ name=name,
+ loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
def get_model_fn(params,
graph_builder_class,