aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/canned/head_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/canned/head_test.py')
-rw-r--r--tensorflow/python/estimator/canned/head_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py
index 4497cd26f2..0a4ea7d81c 100644
--- a/tensorflow/python/estimator/canned/head_test.py
+++ b/tensorflow/python/estimator/canned/head_test.py
@@ -987,14 +987,12 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
spec.loss.eval()
def test_multi_dim_train_weights_wrong_outer_dim(self):
- """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3]."""
+ """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 2]."""
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(
n_classes=3, weight_column='weights')
logits = np.array([[[10, 0, 0], [12, 0, 0]],
[[0, 10, 0], [0, 15, 0]]], dtype=np.float32)
labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
- weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]],
- [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]])
weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
def _no_op_train_fn(loss):
del loss
@@ -1010,8 +1008,10 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase):
_initialize_variables(self, monitored_session.Scaffold())
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
- r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'):
- spec.loss.eval({weights_placeholder: weights})
+ r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'):
+ spec.loss.eval({
+ weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]],
+ [[2., 2.1], [2.5, 2.6]]])})
def test_multi_dim_weighted_eval(self):
"""Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""