aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/estimator/canned/head.py7
-rw-r--r--tensorflow/python/estimator/canned/head_test.py8
2 files changed, 9 insertions, 6 deletions
diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py
index d2c5772483..80d109d927 100644
--- a/tensorflow/python/estimator/canned/head.py
+++ b/tensorflow/python/estimator/canned/head.py
@@ -200,8 +200,11 @@ def _check_labels(labels, expected_labels_dimension):
dim1 = static_shape[1]
if (dim1 is not None) and (dim1 != expected_labels_dimension):
raise ValueError(
- 'labels shape must be [batch_size, labels_dimension], got %s.' %
- (static_shape,))
+ 'Mismatched label shape. '
+ 'Classifier configured with n_classes=%s. Received %s. '
+ 'Suggested Fix: check your n_classes argument to the estimator '
+ 'and/or the shape of your label.' %
+ (expected_labels_dimension, dim1))
assert_dimension = check_ops.assert_equal(
expected_labels_dimension, labels_shape[1], message=err_msg)
with ops.control_dependencies([assert_dimension]):
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 23678013c6..fa3d5b44eb 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -139,7 +139,7 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
features = {'x': np.array(((42.,),))}
# Static shape.
- with self.assertRaisesRegexp(ValueError, 'labels shape'):
+ with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
head.create_loss(
features=features,
mode=model_fn.ModeKeys.EVAL,
@@ -889,7 +889,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase):
logits_2x1 = np.array(((45.,), (41.,),))
# Static shape.
- with self.assertRaisesRegexp(ValueError, 'labels shape'):
+ with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
head.create_loss(
features={'x': np.array(((42.,),))},
mode=model_fn.ModeKeys.EVAL,
@@ -1692,7 +1692,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
values_1d = np.array(((43.,), (44.,),))
# Static shape.
- with self.assertRaisesRegexp(ValueError, 'labels shape'):
+ with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
head.create_loss(
features={'x': values_1d},
mode=model_fn.ModeKeys.EVAL,
@@ -1737,7 +1737,7 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase):
values_1d = np.array(((43.,), (44.,),))
# Static shape.
- with self.assertRaisesRegexp(ValueError, 'labels shape'):
+ with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):
head.create_loss(
features={'x': values_1d},
mode=model_fn.ModeKeys.TRAIN,