aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-23 11:02:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-25 02:43:16 -0700
commit0526238462dc39c7b90733102583eea55a0d62bc (patch)
tree2fa2fe2069a5b963230076db97fc6beb08492fd4
parent651ebf95adf88924c9dfb9cddac3d96a30dffed3 (diff)
Changes loss_reduction default to SUM_OVER_BATCH_SIZE for multi_label_head.
PiperOrigin-RevId: 190244159
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head.py7
-rw-r--r--tensorflow/contrib/estimator/python/estimator/head_test.py108
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head_test.py50
3 files changed, 81 insertions, 84 deletions
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
index 42e1b7b68c..74da2cbb3f 100644
--- a/tensorflow/contrib/estimator/python/estimator/head.py
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -304,7 +304,7 @@ def multi_label_head(n_classes,
weight_column=None,
thresholds=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
"""Creates a `_Head` for multi-label classification.
@@ -355,7 +355,8 @@ def multi_label_head(n_classes,
string type and have any value in `label_vocabulary`. Also there will be
errors if vocabulary is not provided and labels are string.
loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
- reduce training loss over batch. Defaults to `SUM`.
+ reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely
+ weighted sum of losses divided by batch size. See `tf.losses.Reduction`.
loss_fn: Optional loss function.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`. Also used as `name_scope` when creating ops.
@@ -404,7 +405,7 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access
weight_column=None,
thresholds=None,
label_vocabulary=None,
- loss_reduction=losses.Reduction.SUM,
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
loss_fn=None,
name=None):
self._n_classes = n_classes
diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py
index 776f0ee341..8837dfdc6c 100644
--- a/tensorflow/contrib/estimator/python/estimator/head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/head_test.py
@@ -272,9 +272,9 @@ class MultiLabelHead(test.TestCase):
logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
- # loss = labels * -log(sigmoid(logits)) +
- # (1 - labels) * -log(1 - sigmoid(logits))
- expected_training_loss = np.sum(
+ # loss = (labels * -log(sigmoid(logits)) +
+ # (1 - labels) * -log(1 - sigmoid(logits))) / 2
+ expected_training_loss = 0.5 * np.sum(
_sigmoid_cross_entropy(labels=labels, logits=logits))
actual_training_loss = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
@@ -298,7 +298,7 @@ class MultiLabelHead(test.TestCase):
# For large logits, this is approximated as:
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits
- expected_training_loss = np.sum(
+ expected_training_loss = 0.5 * np.sum(
np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32))
actual_training_loss = head.create_loss(
features={'x': np.array(((42,),), dtype=np.int32)},
@@ -361,7 +361,7 @@ class MultiLabelHead(test.TestCase):
labels=labels_input)[0]
with self.test_session():
_initialize_variables(self, monitored_session.Scaffold())
- self.assertAllClose(np.sum(loss), actual_training_loss.eval())
+ self.assertAllClose(np.sum(loss) / 2., actual_training_loss.eval())
def test_eval_create_loss_loss_fn_wrong_shape(self):
"""Tests custom loss_fn that returns Tensor of unexpected shape."""
@@ -438,12 +438,13 @@ class MultiLabelHead(test.TestCase):
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -468,14 +469,13 @@ class MultiLabelHead(test.TestCase):
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -533,14 +533,13 @@ class MultiLabelHead(test.TestCase):
labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -562,15 +561,14 @@ class MultiLabelHead(test.TestCase):
labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
# loss = labels * -log(sigmoid(logits)) +
# (1 - labels) * -log(1 - sigmoid(logits))
- # Sum over examples.
- expected_loss = (
- np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
- )
+ # Sum over examples, divide by batch_size.
+ expected_loss = 0.5 * np.sum(
+ _sigmoid_cross_entropy(labels=labels, logits=logits))
keys = metric_keys.MetricKeys
expected_metrics = {
# Average loss over examples.
- keys.LOSS_MEAN: expected_loss / 2,
+ keys.LOSS_MEAN: expected_loss,
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.3333,
@@ -603,8 +601,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, weighted sum over examples.
- expected_loss = 25.
+ # Average over classes, weighted sum over examples, divide by batch_size.
+ # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2
+ expected_loss = 12.5
spec = head.create_estimator_spec(
features={
@@ -617,8 +616,8 @@ class MultiLabelHead(test.TestCase):
keys = metric_keys.MetricKeys
expected_metrics = {
- # Average loss over weighted examples.
- keys.LOSS_MEAN: expected_loss / 3,
+ # Average loss over weighted examples (denominator is sum(weights)).
+ keys.LOSS_MEAN: expected_loss * (2. / 3.),
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.2000,
@@ -663,7 +662,7 @@ class MultiLabelHead(test.TestCase):
# (1 - labels) * (logits > 0) * logits
expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]
expected_weights = [[1.], [2.]]
- expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.
+ expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2.
training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
features={
'x': np.array(((42,),), dtype=np.int32),
@@ -809,11 +808,8 @@ class MultiLabelHead(test.TestCase):
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- _assert_simple_summaries(self, {
- metric_keys.MetricKeys.LOSS: expected_loss,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
- }, summary_str, tol)
+ _assert_simple_summaries(
+ self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol)
def test_train(self):
head = head_lib.multi_label_head(n_classes=2)
@@ -823,8 +819,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -840,8 +837,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -858,8 +856,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
self._test_train(
head=head, logits=logits, labels=labels, expected_loss=expected_loss)
@@ -871,8 +870,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # Average over classes, sum over examples, divide by batch_size.
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
class _Optimizer(object):
@@ -952,8 +952,9 @@ class MultiLabelHead(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, weighted sum over examples.
- expected_loss = 25.
+ # Average over classes, weighted sum over examples, divide by batch_size.
+ # loss = ( 1 * (10 + 10) / 2 + 2 * (15 + 0) / 2 ) / 2
+ expected_loss = 12.5
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -987,11 +988,8 @@ class MultiLabelHead(test.TestCase):
self.assertEqual(
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- _assert_simple_summaries(self, {
- metric_keys.MetricKeys.LOSS: expected_loss,
- # Average loss over weighted examples.
- metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
- }, summary_str, tol)
+ _assert_simple_summaries(
+ self, {metric_keys.MetricKeys.LOSS: expected_loss,}, summary_str, tol)
def test_multi_dim_weighted_train_create_loss(self):
"""Logits and labels of shape [2, 2, 3], weights [2, 2]."""
@@ -1008,8 +1006,8 @@ class MultiLabelHead(test.TestCase):
expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]]
# weights are reshaped to [2, 2, 1] to match logits.
expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_training_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_training_loss = 9.9167
training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
features={'weights': weights},
mode=model_fn.ModeKeys.TRAIN,
@@ -1035,8 +1033,8 @@ class MultiLabelHead(test.TestCase):
weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
# loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
# = [[20/3, 10/3], [4, 8]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_loss = 9.9167
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -1124,11 +1122,11 @@ class MultiLabelHead(test.TestCase):
weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
# loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
# = [[20/3, 10/3], [4, 8]]
- # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
- expected_loss = 39.6667
+ # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167
+ expected_loss = 9.9167
keys = metric_keys.MetricKeys
expected_metrics = {
- keys.LOSS_MEAN: expected_loss / np.sum(weights),
+ keys.LOSS_MEAN: expected_loss * (4. / np.sum(weights)),
# auc and auc_pr cannot be reliably calculated for only 4 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC: 0.4977,
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 43cc157a1f..74d3d6d728 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -299,10 +299,11 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
# head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
- # Average over classes, weighted sum over batch and heads.
- expected_loss_head1 = 17.5
- expected_loss_head2 = 30.0
+ # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15
+ expected_loss_head1 = 8.75
+ expected_loss_head2 = 15.
expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
spec = multi_head.create_estimator_spec(
@@ -316,8 +317,8 @@ class MultiHeadTest(test.TestCase):
keys.LOSS + '/head1': expected_loss_head1,
keys.LOSS + '/head2': expected_loss_head2,
# Average loss over examples.
- keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
- keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
+ keys.LOSS_MEAN + '/head1': expected_loss_head1,
+ keys.LOSS_MEAN + '/head2': expected_loss_head2,
# auc and auc_pr cannot be reliably calculated for only 4-6 samples, but
# this assert tests that the algorithm remains consistent.
keys.AUC + '/head1': 0.1667,
@@ -363,8 +364,8 @@ class MultiHeadTest(test.TestCase):
tol = 1e-3
with self.test_session():
# Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2]
- # (averaged over classes, sum-reduced over examples).
- self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol)
+ # (averaged over classes, averaged over examples).
+ self.assertAllClose(8.75, loss.eval(), rtol=tol, atol=tol)
def test_train_create_loss_two_heads_with_weights(self):
# Use different example weighting for each head weighting.
@@ -399,18 +400,18 @@ class MultiHeadTest(test.TestCase):
with self.test_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
- # training_loss = 1 * 10 + 2 * 7.5 = 25
+ # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
# head-weighted unreduced_loss = 1 * [10, 7.5]
self.assertAllClose(
[[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
# loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
# = [20, 10]
- # training_loss = 2 * 20 + 3 * 10 = 70
+ # training_loss = (2 * 20 + 3 * 10) / 2 = 35
# head-weighted unreduced_loss = 2 * [20, 10]
self.assertAllClose(
[[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
- # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
- self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
+ # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5
+ self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol)
# head-weighted example weights
self.assertAllClose(
[[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
@@ -447,18 +448,18 @@ class MultiHeadTest(test.TestCase):
with self.test_session():
# loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
# = [10, 7.5]
- # training_loss = 1 * 10 + 2 * 7.5 = 25
+ # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5
# head-weighted unreduced_loss = 1 * [10, 7.5]
self.assertAllClose(
[[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
# loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
# = [20, 10]
- # training_loss = 2 * 20 + 3 * 10 = 70
+ # training_loss = (2 * 20 + 3 * 10) / 2 = 35
# head-weighted unreduced_loss = 2 * [20, 10]
self.assertAllClose(
[[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
- # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
- self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
+ # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5
+ self.assertAllClose(82.5, training_loss.eval(), rtol=tol, atol=tol)
# head-weighted example weights
self.assertAllClose(
[[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
@@ -511,8 +512,8 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
return string_ops.string_join(
@@ -546,8 +547,6 @@ class MultiHeadTest(test.TestCase):
_assert_simple_summaries(self, {
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
}, summary_str, tol)
def test_train_one_head_with_optimizer(self):
@@ -560,8 +559,8 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# expected_unweighted_loss = [[10., 10.], [15., 0.]]
- # Average over classes, sum over weights.
- expected_loss = 17.5
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
+ expected_loss = 8.75
expected_train_result = 'my_train_op'
class _Optimizer(object):
@@ -607,10 +606,12 @@ class MultiHeadTest(test.TestCase):
# loss = labels * (logits < 0) * (-logits) +
# (1 - labels) * (logits > 0) * logits =>
# head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
+ # loss = ( (10 + 10) / 2 + (15 + 0) / 2 ) / 2 = 8.75
# head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
+ # loss = ( (20 + 20 + 20) / 3 + (30 + 0 + 0) / 3 ) / 2 = 15
# Average over classes, weighted sum over batch and heads.
- expected_loss_head1 = 17.5
- expected_loss_head2 = 30.0
+ expected_loss_head1 = 8.75
+ expected_loss_head2 = 15.0
expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
expected_train_result = 'my_train_op'
def _train_op_fn(loss):
@@ -646,9 +647,6 @@ class MultiHeadTest(test.TestCase):
metric_keys.MetricKeys.LOSS: expected_loss,
metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,
metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
- # Average loss over examples.
- metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
- metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
}, summary_str, tol)