aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 16b2633b3a..06e5e9d9df 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -503,12 +503,6 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
x=x, input_fn=input_fn, batch_size=batch_size)
def _loss_vec(self, logits, target):
- # Check that we got int32/int64 for classification.
- if (not target.dtype.is_compatible_with(dtypes.int64) and
- not target.dtype.is_compatible_with(dtypes.int32)):
- raise ValueError("Target's dtype should be int32, int64 or compatible. "
- "Instead got %s." % target.dtype)
-
if self._n_classes == 2:
# sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
if len(target.get_shape()) == 1:
@@ -516,6 +510,11 @@ class DNNLinearCombinedClassifier(_DNNLinearCombinedBaseEstimator):
loss_vec = nn.sigmoid_cross_entropy_with_logits(
logits, math_ops.to_float(target))
else:
+ # Check that we got int32/int64 for classification.
+ if (not target.dtype.is_compatible_with(dtypes.int64) and
+ not target.dtype.is_compatible_with(dtypes.int32)):
+ raise ValueError("Target's dtype should be int32, int64 or compatible. "
+ "Instead got %s." % target.dtype)
# sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
if len(target.get_shape()) == 2:
target = array_ops.squeeze(target, squeeze_dims=[1])