diff options
author | 2016-06-06 12:11:35 -0800 | |
---|---|---|
committer | 2016-06-06 13:17:28 -0700 | |
commit | e70c452a18403b368ea845cfb654079386a00fd8 (patch) | |
tree | 62270f2e53dcfb268639134b6ee2614552ef6a21 /tensorflow | |
parent | b65da078ec8c0446cee95e2ab8e6806a989f9936 (diff) |
Allows probability as label for logistic regression in DNNLinearCombined
Change: 124169174
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 16b2633b3a..06e5e9d9df 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -503,12 +503,6 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator): x=x, input_fn=input_fn, batch_size=batch_size) def _loss_vec(self, logits, target): - # Check that we got int32/int64 for classification. - if (not target.dtype.is_compatible_with(dtypes.int64) and - not target.dtype.is_compatible_with(dtypes.int32)): - raise ValueError("Target's dtype should be int32, int64 or compatible. " - "Instead got %s." % target.dtype) - if self._n_classes == 2: # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target. if len(target.get_shape()) == 1: @@ -516,6 +510,11 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator): loss_vec = nn.sigmoid_cross_entropy_with_logits( logits, math_ops.to_float(target)) else: + # Check that we got int32/int64 for classification. + if (not target.dtype.is_compatible_with(dtypes.int64) and + not target.dtype.is_compatible_with(dtypes.int32)): + raise ValueError("Target's dtype should be int32, int64 or compatible. " + "Instead got %s." % target.dtype) # sparse_softmax_cross_entropy_with_logits requires [batch_size] target. if len(target.get_shape()) == 2: target = array_ops.squeeze(target, squeeze_dims=[1]) |