aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-01 13:33:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-01 17:27:10 -0800
commit7092e612c1ec51b4aeafe9201706331dd4c3199e (patch)
tree8535fa9e169b31e48bc0343346d26a502c1be22a
parentc460a245a25467a66d7319544afb92407057b424 (diff)
Fixes a type conversion bug in losses.compute_weighted_loss for reduction=SUM_OVER_BATCH_SIZE.
PiperOrigin-RevId: 184186573
-rw-r--r--tensorflow/python/kernel_tests/losses_test.py28
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py2
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
index 81af3a0887..00c6706593 100644
--- a/tensorflow/python/kernel_tests/losses_test.py
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -1345,6 +1345,34 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertAllClose(
np.mean(self._raw_losses), unweighted_loss.eval())
+ def testUnweightedFromPlaceholder(self):
+ for reduction in losses.Reduction.all():
+ with ops.Graph().as_default() as g:
+ self.assertEqual(0, len(util.get_losses()))
+ raw_losses = array_ops.placeholder(dtype=dtypes.float32)
+ feed_dict = {raw_losses: self._raw_losses}
+ unweighted_losses = (
+ losses.compute_weighted_loss(raw_losses, reduction=reduction),
+ losses.compute_weighted_loss(
+ raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction),
+ losses.compute_weighted_loss(
+ raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction),
+ )
+ self.assertEqual(3, len(util.get_losses()))
+ with self.test_session(g):
+ for unweighted_loss in unweighted_losses:
+ if reduction == losses.Reduction.NONE:
+ self.assertAllClose(
+ self._raw_losses, unweighted_loss.eval(feed_dict))
+ elif reduction == losses.Reduction.SUM:
+ self.assertAllClose(
+ np.sum(self._raw_losses), unweighted_loss.eval(feed_dict))
+ else:
+ # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS,
+ # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE.
+ self.assertAllClose(
+ np.mean(self._raw_losses), unweighted_loss.eval(feed_dict))
+
def testScalarWeight(self):
with ops.Graph().as_default():
self.assertEqual(0, len(util.get_losses()))
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 73563486e1..e75a9b22e4 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -151,7 +151,7 @@ def _num_present(losses, weights, per_batch=False):
def _num_elements(losses):
"""Computes the number of elements in `losses` tensor."""
with ops.name_scope(None, "num_elements", values=[losses]) as scope:
- return array_ops.size(losses, name=scope, out_type=losses.dtype)
+ return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype)
@tf_export("losses.compute_weighted_loss")