author    A. Unique TensorFlower <nobody@tensorflow.org>  2016-06-06 12:11:35 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-06-06 13:17:28 -0700
commit e70c452a18403b368ea845cfb654079386a00fd8 (patch)
tree   62270f2e53dcfb268639134b6ee2614552ef6a21 /tensorflow
parent b65da078ec8c0446cee95e2ab8e6806a989f9936 (diff)
Allows probability as label for logistic regression in DNNLinearCombined
Change: 124169174
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py | 11
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 16b2633b3a..06e5e9d9df 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -503,12 +503,6 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
         x=x, input_fn=input_fn, batch_size=batch_size)
 
   def _loss_vec(self, logits, target):
-    # Check that we got int32/int64 for classification.
-    if (not target.dtype.is_compatible_with(dtypes.int64) and
-        not target.dtype.is_compatible_with(dtypes.int32)):
-      raise ValueError("Target's dtype should be int32, int64 or compatible. "
-                       "Instead got %s." % target.dtype)
-
     if self._n_classes == 2:
       # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
       if len(target.get_shape()) == 1:
@@ -516,6 +510,11 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
       loss_vec = nn.sigmoid_cross_entropy_with_logits(
           logits, math_ops.to_float(target))
     else:
+      # Check that we got int32/int64 for classification.
+      if (not target.dtype.is_compatible_with(dtypes.int64) and
+          not target.dtype.is_compatible_with(dtypes.int32)):
+        raise ValueError("Target's dtype should be int32, int64 or compatible. "
+                         "Instead got %s." % target.dtype)
       # sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
       if len(target.get_shape()) == 2:
         target = array_ops.squeeze(target, squeeze_dims=[1])
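
The net effect of the change: the integer-dtype check no longer guards all of _loss_vec, only the multiclass branch, so the binary (logistic) path can take float probabilities as labels, which sigmoid_cross_entropy_with_logits handles after the math_ops.to_float cast shown above. Below is a minimal NumPy-only sketch of the resulting branching; the function name, shapes, and the stable sigmoid cross-entropy formula are illustrative assumptions, not the TensorFlow implementation in the diff.

    # NumPy sketch of the post-change _loss_vec branching (illustrative only).
    import numpy as np

    def loss_vec(logits, target, n_classes):
      if n_classes == 2:
        # Binary/logistic path: float targets (e.g. probabilities in [0, 1])
        # are accepted now that the integer-dtype check no longer runs here.
        x = np.asarray(logits, dtype=np.float64).reshape(-1, 1)
        z = np.asarray(target, dtype=np.float64).reshape(-1, 1)
        # Numerically stable sigmoid cross-entropy:
        #   max(x, 0) - x*z + log(1 + exp(-|x|))
        return (np.maximum(x, 0) - x * z + np.log1p(np.exp(-np.abs(x)))).ravel()
      # Multiclass path: still requires integer class ids.
      target = np.asarray(target)
      if not np.issubdtype(target.dtype, np.integer):
        raise ValueError("Target's dtype should be int32, int64 or compatible. "
                         "Instead got %s." % target.dtype)
      logits = np.asarray(logits, dtype=np.float64)
      target = target.reshape(-1)
      # Sparse softmax cross-entropy over [batch_size] integer targets.
      shifted = logits - logits.max(axis=1, keepdims=True)
      log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
      return -log_probs[np.arange(target.shape[0]), target]

    # Binary case with probability labels, which this change permits:
    print(loss_vec(logits=[0.3, -1.2], target=[0.8, 0.1], n_classes=2))
    # Multiclass case still rejects float targets:
    # loss_vec(logits=[[0.1, 0.2, 0.3]], target=[0.5], n_classes=3)  # ValueError

Moving the check into the else branch is what makes the first call legal: the sigmoid path casts the target to float regardless, while sparse softmax cross-entropy genuinely needs integer class ids.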