diff options
Diffstat (limited to 'tensorflow/python/estimator/canned/head_test.py')
-rw-r--r-- | tensorflow/python/estimator/canned/head_test.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 4497cd26f2..0a4ea7d81c 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -987,14 +987,12 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): spec.loss.eval() def test_multi_dim_train_weights_wrong_outer_dim(self): - """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3].""" + """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 2].""" head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( n_classes=3, weight_column='weights') logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]], dtype=np.float32) labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64) - weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]], - [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]]) weights_placeholder = array_ops.placeholder(dtype=dtypes.float32) def _no_op_train_fn(loss): del loss @@ -1010,8 +1008,10 @@ class MultiClassHeadWithSoftmaxCrossEntropyLoss(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'): - spec.loss.eval({weights_placeholder: weights}) + r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 2\]'): + spec.loss.eval({ + weights_placeholder: np.array([[[1., 1.1], [1.5, 1.6]], + [[2., 2.1], [2.5, 2.6]]])}) def test_multi_dim_weighted_eval(self): """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2].""" |