diff options
author | 2018-02-01 13:33:20 -0800 | |
---|---|---|
committer | 2018-02-01 17:27:10 -0800 | |
commit | 7092e612c1ec51b4aeafe9201706331dd4c3199e (patch) | |
tree | 8535fa9e169b31e48bc0343346d26a502c1be22a | |
parent | c460a245a25467a66d7319544afb92407057b424 (diff) |
Fixes a type conversion bug in losses.compute_weighted_loss for reduction=SUM_OVER_BATCH_SIZE.
PiperOrigin-RevId: 184186573
-rw-r--r-- | tensorflow/python/kernel_tests/losses_test.py | 28 | ||||
-rw-r--r-- | tensorflow/python/ops/losses/losses_impl.py | 2 |
2 files changed, 29 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 81af3a0887..00c6706593 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -1345,6 +1345,34 @@ class ComputeWeightedLossTest(test.TestCase): self.assertAllClose( np.mean(self._raw_losses), unweighted_loss.eval()) + def testUnweightedFromPlaceholder(self): + for reduction in losses.Reduction.all(): + with ops.Graph().as_default() as g: + self.assertEqual(0, len(util.get_losses())) + raw_losses = array_ops.placeholder(dtype=dtypes.float32) + feed_dict = {raw_losses: self._raw_losses} + unweighted_losses = ( + losses.compute_weighted_loss(raw_losses, reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 1)), reduction=reduction), + losses.compute_weighted_loss( + raw_losses, weights=np.ones((1, 1, 4)), reduction=reduction), + ) + self.assertEqual(3, len(util.get_losses())) + with self.test_session(g): + for unweighted_loss in unweighted_losses: + if reduction == losses.Reduction.NONE: + self.assertAllClose( + self._raw_losses, unweighted_loss.eval(feed_dict)) + elif reduction == losses.Reduction.SUM: + self.assertAllClose( + np.sum(self._raw_losses), unweighted_loss.eval(feed_dict)) + else: + # reduction one of MEAN, SUM_OVER_NONZERO_WEIGHTS, + # SUM_BY_NONZERO_WEIGHTS or SUM_OVER_BATCH_SIZE. + self.assertAllClose( + np.mean(self._raw_losses), unweighted_loss.eval(feed_dict)) + def testScalarWeight(self): with ops.Graph().as_default(): self.assertEqual(0, len(util.get_losses())) diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 73563486e1..e75a9b22e4 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -151,7 +151,7 @@ def _num_present(losses, weights, per_batch=False): def _num_elements(losses): """Computes the number of elements in `losses` tensor.""" with ops.name_scope(None, "num_elements", values=[losses]) as scope: - return array_ops.size(losses, name=scope, out_type=losses.dtype) + return math_ops.cast(array_ops.size(losses, name=scope), dtype=losses.dtype) @tf_export("losses.compute_weighted_loss") |