author     2016-12-02 08:56:45 -0800
committer  2016-12-02 09:06:39 -0800
commit     896285a8dca7bddbf328b3728683acf619f26c13 (patch)
tree       863dcd57b331cfba86f2d73a37f7e90b1d4dede5
parent     d8b037828ab66f25a7b526848cdc3fa9b3b9f198 (diff)
Moves tf.contrib.losses into core, with changes.
Change: 140855283
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py |   16
-rw-r--r--  tensorflow/python/BUILD                             |    1
-rw-r--r--  tensorflow/python/__init__.py                       |   10
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                |    7
-rw-r--r--  tensorflow/python/kernel_tests/losses_test.py       | 1142
-rw-r--r--  tensorflow/python/ops/losses/BUILD                  |   39
-rw-r--r--  tensorflow/python/ops/losses/__init__.py            |   21
-rw-r--r--  tensorflow/python/ops/losses/losses.py              |  588
-rw-r--r--  tensorflow/python/ops/losses/util.py                |   88
9 files changed, 1907 insertions(+), 5 deletions(-)
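The practical effect of this change: the loss functions become importable from the core `tf.losses` namespace, with `labels` as the first argument, while the `tf.contrib.losses` versions are kept but deprecated (see the decorators added below). A minimal sketch of the new call style, assuming the 2016-era graph-and-session TensorFlow API that this commit targets; the expected value follows `testNonZeroLossWithPythonScalarWeight` in the new test file:

```python
import tensorflow as tf

predictions = tf.constant([[4.0, 8.0, 12.0], [8.0, 1.0, 3.0]])
labels = tf.constant([[1.0, 9.0, 2.0], [-5.0, -2.0, 6.0]])

# New core namespace added by this commit. The loss tensor is also added to
# the tf.GraphKeys.LOSSES collection by default.
loss = tf.losses.absolute_difference(labels, predictions, weights=2.3)

with tf.Session() as sess:
  # The mean absolute difference is 5.5; the scalar weight scales it to 12.65.
  print(sess.run(loss))
```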
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index c17b251d3e..780def4269 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import nn_ops
-
+from tensorflow.python.util.deprecation import deprecated
 
 __all__ = ["absolute_difference",
            "add_loss",
@@ -141,6 +141,7 @@ def _safe_mean(losses, num_present):
   return _safe_div(total_loss, num_present)
 
 
+@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
 @deprecated_args(
     "2016-11-25", "`weight` is being deprecated, use `weights`.", "weight")
 def compute_weighted_loss(
@@ -235,6 +236,7 @@ def _num_present(losses, weights, per_batch=False):
   return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
 
 
+@deprecated("2016-12-30", "Use tf.losses.add_loss instead.")
 @add_arg_scope
 def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
   """Adds a externally defined loss to the collection of losses.
@@ -247,6 +249,7 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
   ops.add_to_collection(loss_collection, loss)
 
 
+@deprecated("2016-12-30", "Use tf.losses.get_losses instead.")
 def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
   """Gets the list of losses from the loss_collection.
 
@@ -260,6 +263,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
   return ops.get_collection(loss_collection, scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.get_regularization_losses instead.")
 def get_regularization_losses(scope=None):
   """Gets the regularization losses.
 
@@ -272,6 +276,7 @@ def get_regularization_losses(scope=None):
   return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.get_total_loss instead.")
 def get_total_loss(add_regularization_losses=True, name="total_loss"):
   """Returns a tensor whose value represents the total loss.
 
@@ -294,6 +299,7 @@ def get_total_loss(add_regularization_losses=True, name="total_loss"):
   return math_ops.add_n(losses, name=name)
 
 
+@deprecated("2016-12-30", "Use tf.losses.absolute_difference instead.")
 @deprecated_args(
     "2016-11-25",
     "`targets` is being deprecated, use `labels`."
@@ -339,6 +345,7 @@ def absolute_difference(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.sigmoid_cross_entropy instead.")
 @deprecated_args(
     "2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
 def sigmoid_cross_entropy(
@@ -389,6 +396,7 @@ def sigmoid_cross_entropy(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.softmax_cross_entropy instead.")
 @deprecated_args(
     "2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
 def softmax_cross_entropy(
@@ -440,6 +448,7 @@ def softmax_cross_entropy(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.sparse_softmax_cross_entropy instead.")
 @deprecated_args(
     "2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
 def sparse_softmax_cross_entropy(
@@ -479,6 +488,7 @@ def sparse_softmax_cross_entropy(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.log_loss instead.")
 @deprecated_args(
     "2016-11-25",
     "`targets` is being deprecated, use `labels`."
@@ -528,6 +538,7 @@ def log_loss(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.hinge_loss instead.")
 @deprecated_args(
     "2016-11-25", "`target` is being deprecated, use `labels`.", "target")
 def hinge_loss(logits, labels=None, scope=None, target=None):
@@ -557,6 +568,7 @@ def hinge_loss(logits, labels=None, scope=None, target=None):
   return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
 
 
+@deprecated("2016-12-30", "Use tf.losses.mean_squared_error instead.")
 @deprecated_args(
     "2016-11-25",
     "`targets` is being deprecated, use `labels`."
@@ -602,6 +614,7 @@ def mean_squared_error(
   return compute_weighted_loss(losses, weights, scope=scope)
 
 
+@deprecated("2016-12-30", "Use tf.losses.mean_pairwise_squared_error instead.")
 @deprecated_args(
     "2016-11-25",
     "`targets` is being deprecated, use `labels`."
@@ -691,6 +704,7 @@ def mean_pairwise_squared_error(
   return mean_loss
 
 
+@deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.")
 @deprecated_args(
     "2016-11-25",
     "`targets` is being deprecated, use `labels`."
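Each contrib function above is stamped with `@deprecated` from `tensorflow.python.util.deprecation`, imported in the first hunk of this file. As a rough mental model only (the real decorator also rewrites the docstring and logs through TensorFlow's logging machinery rather than `warnings`), it behaves like a wrapper that emits a warning on every call:

```python
import functools
import warnings


def deprecated(date, instructions):
  """Simplified sketch of tensorflow.python.util.deprecation.deprecated."""
  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Warn the caller, pointing at the replacement API named in the diff.
      warnings.warn(
          "%s is deprecated and will be removed after %s. "
          "Instructions for updating: %s"
          % (func.__name__, date, instructions),
          DeprecationWarning, stacklevel=2)
      return func(*args, **kwargs)
    return wrapper
  return decorator


@deprecated("2016-12-30", "Use tf.losses.get_losses instead.")
def get_losses(scope=None):
  return []  # placeholder body; the real function reads the losses collection
```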
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 180aa291ec..c5c2f77378 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -43,6 +43,7 @@ py_library(
         ":training",
         ":ops",
         ":test_ops",
+        "//tensorflow/python/ops/losses",
         "//tensorflow/python/debug:debug_py",
     ] + if_not_windows([
         "//tensorflow/contrib:contrib_py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index cd2f4fa328..e323c9b6a4 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -87,6 +87,7 @@ from tensorflow.python.ops import nn
 from tensorflow.python.ops import resources
 from tensorflow.python.ops import sdca_ops as sdca
 from tensorflow.python.ops import image_ops as image
+from tensorflow.python.ops import losses
 from tensorflow.python.ops import sets
 from tensorflow.python.user_ops import user_ops
 from tensorflow.python.util import compat
@@ -218,6 +219,7 @@ _allowed_symbols.extend([
     'graph_util',
     'image',
     'logging',
+    'losses',
     'newaxis',
     'nn',
     'python_io',
@@ -245,10 +247,10 @@ _allowed_symbols.extend([
 remove_undocumented(__name__, _allowed_symbols,
                     [framework_lib, array_ops, client_lib, check_ops,
                      compat, constant_op, control_flow_ops, functional_ops,
-                     histogram_ops, io_ops, math_ops, nn, resource_loader,
-                     resources, sets, script_ops, session_ops, sparse_ops,
-                     state_ops, string_ops, summary, tensor_array_ops, train,
-                     layers])
+                     histogram_ops, io_ops, losses, math_ops, nn,
+                     resource_loader, resources, sets, script_ops, session_ops,
+                     sparse_ops, state_ops, string_ops, summary,
+                     tensor_array_ops, train, layers])
 
 # Special dunders that we choose to export:
 _exported_dunders = set([
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index fccb225a2b..3185b1fd06 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -194,6 +194,13 @@ tf_py_test(
 )
 
 tf_py_test(
+    name = "losses_test",
+    size = "small",
+    srcs = ["losses_test.py"],
+    additional_deps = ["//tensorflow:tensorflow_py"],
+)
+
+tf_py_test(
     name = "matrix_inverse_op_test",
     size = "small",
     srcs = ["matrix_inverse_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
new file mode 100644
index 0000000000..2393124ba3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -0,0 +1,1142 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== + +"""Tests for losses.""" +# pylint: disable=unused-import,g-bad-import-order +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +# pylint: enable=unused-import + +import numpy as np +import tensorflow as tf + + +class AbsoluteDifferenceLossTest(tf.test.TestCase): + + def setUp(self): + self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) + self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + + def testValueErrorThrownWhenWeightIsNone(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.absolute_difference( + self._predictions, self._predictions, weights=None) + + def testAllCorrectNoLossWeight(self): + loss = tf.losses.absolute_difference( + self._predictions, self._predictions) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testNonZeroLoss(self): + loss = tf.losses.absolute_difference( + self._labels, self._predictions) + with self.test_session(): + self.assertAlmostEqual(5.5, loss.eval(), 3) + + def testNonZeroLossWithPythonScalarWeight(self): + weights = 2.3 + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + weights = 2.3 + loss = tf.losses.absolute_difference( + self._labels, self._predictions, tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(5.5 * weights, loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + weights = tf.constant([1.2, 0.0], shape=[2,]) + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(5.6, loss.eval(), 3) + + def testNonZeroLossWithTwoDimBatchSpecificWeights(self): + weights = tf.constant([1.2, 0.0], shape=[2, 1]) + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(5.6, loss.eval(), 3) + + def testNonZeroLossWithSampleSpecificWeights(self): + weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(16.6, loss.eval(), 3) + + def testNonZeroLossWithSampleSpecificWeightsMostZero(self): + weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(6.0, loss.eval(), 3) + + def testLossWithSampleSpecificWeightsAllZero(self): + weights = tf.zeros((2, 3)) + loss = tf.losses.absolute_difference( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + +class SoftmaxCrossEntropyLossTest(tf.test.TestCase): + + def testNoneWeightRaisesValueError(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.softmax_cross_entropy(labels, logits, weights=None) + + def testAllCorrect(self): + with self.test_session(): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + loss = 
tf.losses.softmax_cross_entropy(labels, logits) + self.assertEquals('softmax_cross_entropy_loss/value', loss.op.name) + self.assertAlmostEqual(loss.eval(), 0.0, 3) + + def testAllWrong(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + + with self.test_session(): + loss = tf.losses.softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 10.0, 3) + + def testNonZeroLossWithPythonScalarWeight(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = 2.3 + with self.test_session(): + loss = tf.losses.softmax_cross_entropy(labels, logits, weights) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = 2.3 + with self.test_session(): + loss = tf.losses.softmax_cross_entropy( + labels, logits, tf.constant(weights)) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = tf.constant([1.2, 3.4, 5.6], shape=[3]) + with self.test_session(): + loss = tf.losses.softmax_cross_entropy(labels, logits, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) + + def testAllWrongAllWeightsMissing(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = tf.constant([0, 0, 0], shape=[3]) + with self.test_session(): + loss = tf.losses.softmax_cross_entropy(labels, logits, weights) + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testSomeWeightsMissing(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = tf.constant([1.2, 0, 0], shape=[3]) + with self.test_session(): + loss = tf.losses.softmax_cross_entropy(labels, logits, weights) + self.assertAlmostEqual(12.0, loss.eval(), 3) + + def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) + + with self.assertRaises(ValueError): + tf.losses.softmax_cross_entropy( + labels, logits, weights=weights).eval() + + def testSoftmaxLabelSmoothing(self): + with self.test_session(): + # Softmax Cross Entropy Loss is: + # -\sum_i p_i \log q_i + # where for a softmax activation + # \log q_i = x_i - \log \sum_j \exp x_j + # = x_i - x_max - \log \sum_j \exp (x_j - x_max) + # For our activations, [100, -100, -100] the log partion function becomes + # \log ( exp(0) + exp(-200) + exp(-200) ) = 0 + # so our log softmaxes become: [0, -200, -200] + # so our cross entropy loss is: + # -(1 - L + L/n) * 0 + 400 * L/n = 400 L/n + logits = tf.constant([[100.0, -100.0, -100.0]]) + labels = tf.constant([[1, 0, 0]]) + label_smoothing = 0.1 + loss = 
tf.losses.softmax_cross_entropy( + labels, logits, label_smoothing=label_smoothing) + self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value') + expected_value = 400.0 * label_smoothing / 3.0 + self.assertAlmostEqual(loss.eval(), expected_value, 3) + + +class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase): + + def testNoneWeightRaisesValueError(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0], [1], [2]]) + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=None) + + def testAllCorrectInt32Labels(self): + with self.test_session(): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0], [1], [2]], dtype=tf.int32) + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 0.0, 3) + + def testAllCorrectInt64Labels(self): + with self.test_session(): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[0], [1], [2]], dtype=tf.int64) + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 0.0, 3) + + def testAllCorrectNonColumnLabels(self): + with self.test_session(): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([0, 1, 2]) + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 0.0, 3) + + def testAllWrongInt32Labels(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]], dtype=tf.int32) + + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 10.0, 3) + + def testAllWrongInt64Labels(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]], dtype=tf.int64) + + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 10.0, 3) + + def testAllWrongNonColumnLabels(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([2, 0, 1]) + + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 10.0, 3) + + def testNonZeroLossWithPythonScalarWeight(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = 2.3 + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = 2.3 + with self.test_session(): + loss = 
tf.losses.sparse_softmax_cross_entropy( + labels, logits, tf.constant(weights)) + self.assertAlmostEqual(weights * 10.0, loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = tf.constant([1.2, 3.4, 5.6], shape=[3]) + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) + + def testNonZeroLossWithColumnWeights(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = tf.constant([[1.2], [3.4], [5.6]]) + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights) + self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3) + + def testAllWrongAllWeightsMissing(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = tf.constant([0, 0, 0], shape=[3]) + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights) + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testSomeWeightsMissing(self): + logits = tf.constant([[10.0, 0.0, 0.0], + [0.0, 10.0, 0.0], + [0.0, 0.0, 10.0]]) + labels = tf.constant([[2], [0], [1]]) + weights = tf.constant([1.2, 0, 0], shape=[3]) + with self.test_session(): + loss = tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights) + self.assertAlmostEqual(12.0, loss.eval(), 3) + + def testMeasurementSpecificWeightsRaisesException(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[0], [1], [2]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) + + with self.assertRaises(ValueError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=weights).eval() + + def testInconsistentWeightSizeRaisesException(self): + """The weight tensor has incorrect number of elements.""" + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[0], [1], [2]]) + weights = tf.constant([1.2, 3.4, 5.6, 7.8]) + + with self.assertRaises(ValueError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=weights).eval() + + def testInconsistentLabelSizeRaisesException(self): + """The label tensor has incorrect number of elements.""" + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[0], [1], [2], [3]]) + weights = tf.constant([1.2, 3.4, 5.6]) + + with self.assertRaises(ValueError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=weights).eval() + + def testInconsistentWeightShapeRaisesException(self): + """The weight tensor has incorrect shape.""" + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0, -100.0], + [-100.0, -100.0, 100.0, -100.0], + [-100.0, -100.0, -100.0, 100.0]]) + labels = tf.constant([[0], [1], [2], [3]]) + weights = tf.constant([[1.2, 3.4], [5.6, 7.8]]) + + with self.assertRaises(ValueError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=weights).eval() + + def 
testInconsistentLabelShapeRaisesException(self): + """The label tensor has incorrect shape.""" + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0, -100.0], + [-100.0, -100.0, 100.0, -100.0], + [-100.0, -100.0, -100.0, 100.0]]) + labels = tf.constant([[0, 1], [2, 3]]) + weights = tf.constant([1.2, 3.4, 5.6, 7.8]) + + with self.assertRaises(tf.errors.InvalidArgumentError): + tf.losses.sparse_softmax_cross_entropy( + labels, logits, weights=weights).eval() + + +class SigmoidCrossEntropyLossTest(tf.test.TestCase): + + def testAllCorrectSigmoid(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + loss = tf.losses.sigmoid_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self): + logits = tf.placeholder(tf.float32, shape=(None, 1)) + labels = tf.placeholder(tf.float32, shape=(None, 1)) + weights = tf.ones_like(logits, dtype=tf.float32) + + loss = tf.losses.sigmoid_cross_entropy(labels, logits, weights) + + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={ + logits: np.ones((32, 1)), + labels: np.ones((32, 1)), + }) + self.assertAlmostEqual(0.313, loss, 3) + + def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self): + logits = tf.placeholder(tf.float32, shape=(None, 2)) + labels = tf.placeholder(tf.float32, shape=(None, 2)) + weights = tf.ones_like(logits, dtype=tf.float32) + + loss = tf.losses.sigmoid_cross_entropy(labels, logits, weights) + + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={ + logits: np.ones((32, 2)), + labels: np.ones((32, 2)), + }) + self.assertAlmostEqual(0.313, loss, 3) + + def testAllWrongSigmoid(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + loss = tf.losses.sigmoid_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3) + + def testAllWrongSigmoidWithMeasurementSpecificWeights(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0], + [-100.0, 100.0, -100.0], + [-100.0, -100.0, 100.0]]) + labels = tf.constant([[0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + weights = tf.constant([[3, 4, 5], + [2, 6, 0], + [8, 0, 1]]) + loss = tf.losses.sigmoid_cross_entropy( + labels, logits, weights) + self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + self.assertAlmostEqual(1700.0 / 7.0, loss.eval(), 3) + + def testMultiCorrectSigmoid(self): + logits = tf.constant([[100.0, -100.0, 100.0], + [100.0, 100.0, -100.0], + [-100.0, 100.0, 100.0]]) + labels = tf.constant([[1, 0, 1], + [1, 1, 0], + [0, 1, 1]]) + loss = tf.losses.sigmoid_cross_entropy(labels, logits) + self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + + with self.test_session(): + self.assertAlmostEqual(loss.eval(), 0.0, 3) + + def testSigmoidLabelSmoothingCorrect(self): + with self.test_session(): + logits = tf.constant([[100.0, -100.0, -100.0]]) + labels = tf.constant([[1, 0, 1]]) + # Sigmoid cross entropy loss is: + # max(x,0) - x*z + log(1 + exp(-abs(x))) + # The new labels are: + # z' = z * (1 - L) + 0.5 L + # 1 
-> 1 - 0.5 L + # 0 -> 0.5 L + # here we expect: + # 1/3 * (100 - 100 * (1 - 0.5 L) + 0 + # + 0 + 100 * (0.5 L) + 0 + # + 0 + 100 * (1 - 0.5 L) + 0) + # = 1/3 * (100 + 50 L) + label_smoothing = 0.1 + loss = tf.losses.sigmoid_cross_entropy( + labels, logits, label_smoothing=label_smoothing) + self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value') + expected_value = (100.0 + 50.0 * label_smoothing) / 3.0 + self.assertAlmostEqual(loss.eval(), expected_value, 3) + + def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self): + with self.test_session(): + label_smoothing = 0.1 + sigmoid_logits = tf.constant([[100.0, -100.0, -100.0]]) + sigmoid_labels = tf.constant([[1, 0, 1]]) + sigmoid_loss = tf.losses.sigmoid_cross_entropy( + sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing) + + softmax_logits = tf.constant([[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]]) + softmax_labels = tf.constant([[0, 1], [1, 0], [0, 1]]) + softmax_loss = tf.losses.softmax_cross_entropy( + softmax_labels, softmax_logits, label_smoothing=label_smoothing) + self.assertAlmostEqual(sigmoid_loss.eval(), softmax_loss.eval(), 3) + + +class LogLossTest(tf.test.TestCase): + + def setUp(self): + predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3)) + labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3)) + + self._np_predictions = predictions + self._np_labels = labels + + epsilon = 1e-7 + self._expected_losses = np.multiply( + labels, np.log(predictions + epsilon)) + np.multiply( + 1 - labels, np.log(1 - predictions + epsilon)) + + self._predictions = tf.constant(predictions) + self._labels = tf.constant(labels) + + def testValueErrorThrownWhenWeightIsNone(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.log_loss(self._labels, self._labels, weights=None) + + def testAllCorrectNoLossWeight(self): + loss = tf.losses.log_loss(self._labels, self._labels) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testAllCorrectNoLossWeightWithPlaceholder(self): + tf_predictions = tf.placeholder(tf.float32, shape=self._np_labels.shape) + loss = tf.losses.log_loss(self._labels, tf_predictions) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(feed_dict={ + tf_predictions: self._np_labels}), 3) + + def testNonZeroLoss(self): + loss = tf.losses.log_loss(self._labels, self._predictions) + with self.test_session(): + self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0, + loss.eval(), 3) + + def testNonZeroLossWithPythonScalarWeight(self): + weights = 2.3 + loss = tf.losses.log_loss( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, + loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + weights = 2.3 + loss = tf.losses.log_loss( + self._labels, self._predictions, tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, + loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self): + tf_predictions = tf.placeholder(tf.float32, + shape=self._np_predictions.shape) + weights = 2.3 + loss = tf.losses.log_loss( + self._labels, tf_predictions, tf.constant(weights)) + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, + loss, 3) + + def 
testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self): + tf_predictions = tf.placeholder(tf.float32, shape=[None, None]) + weights = 2.3 + loss = tf.losses.log_loss( + self._labels, tf_predictions, tf.constant(weights)) + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) + self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0, + loss, 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + weights = tf.constant([1.2, 3.4], shape=[2]) + expected_losses = np.multiply( + self._expected_losses, + np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3))) + loss = tf.losses.log_loss( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, + loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self): + weights = tf.constant([1.2, 0], shape=[2]) + expected_losses = np.multiply( + self._expected_losses, + np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3))) + loss = tf.losses.log_loss( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, + loss.eval(), 3) + + def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self): + weights = tf.constant([1.2, 0], shape=[2, 1]) + expected_losses = np.multiply( + self._expected_losses, + np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3))) + loss = tf.losses.log_loss( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, + loss.eval(), 3) + + def testWeightsWithSameNumDimsButWrongShapeThrowsException(self): + weights = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4]) + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.log_loss(self._labels, self._predictions, weights) + + def testNonZeroLossWithMeasurementSpecificWeights(self): + weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) + + loss = tf.losses.log_loss( + self._labels, + self._predictions, + tf.constant(weights, shape=(2, 3))) + with self.test_session(): + self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3) + + def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self): + weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) + + tf_predictions = tf.placeholder(tf.float32, shape=[2, 3]) + loss = tf.losses.log_loss( + self._labels, + tf_predictions, + tf.constant(weights, shape=(2, 3))) + + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) + self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3) + + def testNonZeroLossWithSampleSpecificWeightsMostZero(self): + weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) + + loss = tf.losses.log_loss( + self._labels, + self._predictions, + tf.constant(weights, shape=(2, 3))) + with self.test_session(): + self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3) + + def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self): + weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3)) + expected_losses = np.multiply(self._expected_losses, weights) + + tf_predictions = tf.placeholder(tf.float32, shape=[2, 3]) + tf_weights = tf.constant(weights, shape=(2, 3)) + loss = 
tf.losses.log_loss(self._labels, tf_predictions, tf_weights) + + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions}) + self.assertAlmostEqual(-np.sum(expected_losses), loss, 3) + + def testLossWithSampleSpecificWeightsAllZero(self): + tf_weights = tf.zeros(shape=(2, 3)) + loss = tf.losses.log_loss( + self._labels, self._predictions, tf_weights) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + +class HingeLossTest(tf.test.TestCase): + + def testIncompatibleShapes(self): + with self.test_session(): + logits = tf.constant([[-1.0], [2.1]]) + labels = tf.constant([0.0, 1.0]) + with self.assertRaises(ValueError): + _ = tf.losses.hinge_loss(labels, logits).eval() + + def testAllOutsideMargin(self): + with self.test_session(): + logits = tf.constant([1.2, -1.4, -1.0, 2.1]) + labels = tf.constant([1.0, 0.0, 0.0, 1.0]) + loss = tf.losses.hinge_loss(labels, logits) + self.assertAllClose(loss.eval(), 0.0, atol=1e-3) + + def testSomeInsideMargin(self): + with self.test_session(): + logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]]) + labels = tf.constant([[0.0], [0.0], [1.0], [1.0]]) + loss = tf.losses.hinge_loss(labels, logits) + # Examples 1 and 4 are on the correct side of the hyperplane but within + # the margin so they incur some (small) loss. + self.assertAllClose(loss.eval(), 0.175, atol=1e-3) + + def testSomeMisclassified(self): + with self.test_session(): + logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]]) + labels = tf.constant([[[1.0], [0.0], [0.0], [1.0]]]) + loss = tf.losses.hinge_loss(labels, logits) + # Examples 2 and 4 are on the wrong side of the hyperplane so they incur + # some (fairly large) loss. + self.assertAllClose(loss.eval(), 0.875, atol=1e-3) + + +class MeanSquaredErrorTest(tf.test.TestCase): + + def setUp(self): + self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3)) + self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3)) + + def testValueErrorThrownWhenWeightIsNone(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.mean_squared_error( + self._predictions, self._predictions, weights=None) + + def testAllCorrectNoLossWeight(self): + loss = tf.losses.mean_squared_error( + self._predictions, self._predictions) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testNonZeroLoss(self): + loss = tf.losses.mean_squared_error( + self._labels, self._predictions) + with self.test_session(): + self.assertAlmostEqual(49.5, loss.eval(), 3) + + def testNonZeroLossWithPythonScalarWeight(self): + weights = 2.3 + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + weights = 2.3 + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(49.5 * weights, loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + weights = tf.constant([1.2, 3.4], shape=[2,]) + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) + + def testNonZeroLossWithTwoDimBatchSpecificWeights(self): + weights = tf.constant([1.2, 3.4], shape=[2, 1]) + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + 
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3) + + def testNonZeroLossWithSampleSpecificWeights(self): + weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3]) + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(587 / 5.0, loss.eval(), 3) + + def testNonZeroLossWithSampleSpecificWeightsMostZero(self): + weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3]) + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(18.0, loss.eval(), 3) + + def testLossWithSampleSpecificWeightsAllZero(self): + weights = tf.zeros((2, 3)) + loss = tf.losses.mean_squared_error( + self._labels, self._predictions, weights) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + +class MeanPairwiseSquaresErrorTest(tf.test.TestCase): + + def setUp(self): + self._predictions = np.array([[4, 8, 12], + [8, 1, 3]]) + self._labels = np.array([[1, 9, 2], + [-5, -5, 7]]) + + batch_size, dims = self._labels.shape + + # Compute the expected loss 'manually'. + total = np.zeros((batch_size, 1)) + for b in range(batch_size): + for i in range(dims): + for j in range(dims): + x = self._predictions[b, i].item() - self._predictions[b, j].item() + y = self._labels[b, i].item() - self._labels[b, j].item() + tmp = (x-y) * (x-y) + total[b] += tmp + + self._expected_losses = np.divide(total, 9.0) + + def testValueErrorThrownWhenWeightIsNone(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), + weights=None) + + def testAllCorrectNoLossWeight(self): + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels)) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + def testNonZeroLoss(self): + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels)) + with self.test_session(): + self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3) + + def testGradientWithZeroWeight(self): + with tf.Graph().as_default(): + tf.set_random_seed(0) + + inputs = tf.ones((2, 3)) + weights = tf.get_variable('weights', + shape=[3, 4], + initializer=tf.truncated_normal_initializer()) + predictions = tf.matmul(inputs, weights) + + optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9) + loss = tf.losses.mean_pairwise_squared_error( + predictions, + predictions, + 0) + + gradients_to_variables = optimizer.compute_gradients(loss) + + init_op = tf.global_variables_initializer() + + with self.test_session() as sess: + sess.run(init_op) + for grad, _ in gradients_to_variables: + np_grad = sess.run(grad) + self.assertFalse(np.isnan(np_grad).any()) + + def testNonZeroLossWithPythonScalarWeight(self): + weights = 2.3 + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=weights) + with self.test_session(): + self.assertAlmostEqual(weights * np.sum(self._expected_losses), + loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeight(self): + weights = 2.3 + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(weights 
* np.sum(self._expected_losses), + loss.eval(), 3) + + def testNonZeroLossWithScalarZeroWeight(self): + weights = 0 + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=tf.constant(weights)) + with self.test_session(): + self.assertAlmostEqual(0, loss.eval(), 3) + + def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self): + weights = 2.3 + tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape) + tf_labels = tf.placeholder(tf.float32, shape=self._labels.shape) + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf_predictions, + labels=tf_labels, + weights=tf.constant(weights)) + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={ + tf_predictions: self._predictions, + tf_labels: self._labels, + }) + self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss, 3) + + def testNonZeroLossWithOneDimBatchSpecificWeights(self): + weights = np.asarray([2.0, 1.0]).reshape((2, 1)) + expected_losses = np.multiply(weights, self._expected_losses) + + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) + with self.test_session(): + self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3) + + def testZeroLossWithOneDimBatchZeroWeights(self): + weights = np.asarray([0.0, 0.0]).reshape((2, 1)) + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) + with self.test_session(): + self.assertAlmostEqual(0, loss.eval(), 3) + + def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self): + weights = np.asarray([1.2, 3.4]).reshape((2, 1)) + expected_losses = np.multiply(weights, self._expected_losses) + + tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape) + tf_labels = tf.placeholder(tf.int32, shape=self._labels.shape) + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf_predictions, + labels=tf_labels, + weights=tf.constant(weights, shape=[2])) + + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={ + tf_predictions: self._predictions, + tf_labels: self._labels, + }) + self.assertAlmostEqual(np.sum(expected_losses), loss, 3) + + def testLossWithAllZeroBatchSpecificWeights(self): + weights = np.zeros((2, 1)) + loss = tf.losses.mean_pairwise_squared_error( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + weights=tf.constant(weights, shape=[2])) + with self.test_session(): + self.assertAlmostEqual(0.0, loss.eval(), 3) + + +class CosineDistanceLossTest(tf.test.TestCase): + + def setUp(self): + self._predictions = np.asarray([[1, 0, 0], # Batch 1 + [0, 0, -1], + [1, 0, 0], # Batch 2 + [1, 0, 0], + [0, 0, -1], # Batch 3 + [1, 0, 0]]).reshape((3, 2, 3)) + + self._labels = np.asarray([[1, 0, 0], + [0, 0, 1], + [0, 1, 0], + [1, 0, 0], + [0, 0, 1], + [0, 1, 0]]).reshape((3, 2, 3)) + + def testValueErrorThrownWhenWeightIsNone(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.cosine_distance( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), + dim=2, + weights=None) + + def testAllCorrectNoWeights(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._labels), + labels=tf.constant(self._labels), + dim=2) + with self.test_session(): + 
self.assertAlmostEqual(0, loss.eval(), 5) + + def testPartiallyCorrectWithIntegerValues(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + dim=2) + with self.test_session(): + self.assertAlmostEqual(1, loss.eval(), 5) + + def testPartiallyCorrectFloatingPointValues(self): + predictions = np.matrix(( + '0.819031913261206 0.567041924552012 0.087465312324590;' + '-0.665139432070255 -0.739487441769973 -0.103671883216994;' + '0.707106781186548 -0.707106781186548 0')) + labels = np.matrix(( + '0.819031913261206 0.567041924552012 0.087465312324590;' + '0.665139432070255 0.739487441769973 0.103671883216994;' + '0.707106781186548 0.707106781186548 0')) + + tf_preds = tf.constant(predictions, shape=(3, 1, 3), dtype=tf.float32) + tf_labels = tf.constant(labels, shape=(3, 1, 3), dtype=tf.float32) + loss = tf.losses.cosine_distance(tf_labels, tf_preds, dim=2) + + with self.test_session(): + self.assertAlmostEqual(1.0, loss.eval(), 5) + + def testSampleSpecificWeights(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + dim=2, + weights=tf.constant([1, 0, 0])) + with self.test_session(): + self.assertEqual(1.0, loss.eval()) + + def testMeasurementSpecificWeights(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + dim=2, + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + with self.test_session(): + self.assertEqual(3.0 / 4.0, loss.eval()) + + def testValueErrorThrownWithShapelessPlaceholder(self): + tf_predictions = tf.placeholder(tf.float32) + with self.test_session(): + with self.assertRaises(ValueError): + tf.losses.cosine_distance( + predictions=tf_predictions, + labels=tf.constant(self._labels), + dim=2, + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + + def testMeasurementSpecificWeightsWithPlaceholderWithShape(self): + tf_predictions = tf.placeholder(tf.float32, shape=self._labels.shape) + loss = tf.losses.cosine_distance( + predictions=tf_predictions, + labels=tf.constant(self._labels), + dim=2, + weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2))) + with self.test_session() as sess: + loss = sess.run(loss, feed_dict={tf_predictions: self._predictions}) + self.assertEqual(3.0 / 4.0, loss) + + def testZeroLossWhenAllSampleSpecificWeightsAreZero(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + dim=2, + weights=tf.zeros((3,))) + with self.test_session(): + self.assertEqual(0, loss.eval()) + + def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self): + loss = tf.losses.cosine_distance( + predictions=tf.constant(self._predictions), + labels=tf.constant(self._labels), + dim=2, + weights=tf.zeros((3, 2))) + with self.test_session(): + self.assertEqual(0, loss.eval()) + + +class AddLossTest(tf.test.TestCase): + + def testNoCollectLossesBatch2(self): + logits = tf.constant([[1.2, 0.4, -1.0, -1.1]] * 2) + labels = tf.constant([[1.0, 0.0, 0.0, 1.0]] * 2) + self.assertFalse(tf.losses.get_losses()) + tf.losses.absolute_difference(logits, labels, loss_collection=None) + tf.losses.log_loss(logits, labels, loss_collection=None) + tf.losses.mean_squared_error(logits, labels, loss_collection=None) + tf.losses.sigmoid_cross_entropy(logits, labels, loss_collection=None) + tf.losses.softmax_cross_entropy(logits, labels, loss_collection=None) + self.assertFalse(tf.losses.get_losses()) + +if 
__name__ == '__main__': + tf.test.main() diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD new file mode 100644 index 0000000000..4a2edd99b2 --- /dev/null +++ b/tensorflow/python/ops/losses/BUILD @@ -0,0 +1,39 @@ +package( + default_visibility = ["//tensorflow:internal"], + features = [ + "-layering_check", + "-parse_headers", + ], +) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "losses", + srcs = [ + "__init__.py", + "losses.py", + "util.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/python/ops/losses/__init__.py b/tensorflow/python/ops/losses/__init__.py new file mode 100644 index 0000000000..3b0d0d8e5a --- /dev/null +++ b/tensorflow/python/ops/losses/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Loss functions and helpers to manipulate them. +""" + +# pylint: disable=wildcard-import +from tensorflow.python.ops.losses.losses import * +from tensorflow.python.ops.losses.util import * diff --git a/tensorflow/python/ops/losses/losses.py b/tensorflow/python/ops/losses/losses.py new file mode 100644 index 0000000000..e6c2a558b3 --- /dev/null +++ b/tensorflow/python/ops/losses/losses.py @@ -0,0 +1,588 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss operations for use in neural networks. + +Note: All the losses are added to the `GraphKeys.LOSSES` collection by default. 
+ +@@absolute_difference +@@compute_weighted_loss +@@cosine_distance +@@hinge_loss +@@log_loss +@@mean_pairwise_squared_error +@@mean_squared_error +@@sigmoid_cross_entropy +@@softmax_cross_entropy +@@sparse_softmax_cross_entropy + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.losses import util + + +def _scale_losses(losses, weights): + """Computes the scaled loss. + + Args: + losses: A `Tensor` of size [batch_size, d1, ... dN]. + weights: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN]. + The `losses` are reduced (tf.reduce_sum) until its dimension matches + that of `weights` at which point the reduced `losses` are element-wise + multiplied by `weights` and a final reduce_sum is computed on the result. + Conceptually, this operation is equivalent to broadcasting (tiling) + `weights` to be the same size as `losses`, performing an element-wise + multiplication, and summing the result. + + Returns: + A scalar tf.float32 `Tensor` whose value represents the sum of the scaled + `losses`. + """ + # First, compute the sum of the losses over all elements: + start_index = max(0, weights.get_shape().ndims) + reduction_indices = list(range(start_index, losses.get_shape().ndims)) + reduced_losses = math_ops.reduce_sum(losses, + reduction_indices=reduction_indices) + reduced_losses = math_ops.mul(reduced_losses, weights) + return math_ops.reduce_sum(reduced_losses) + + +def _safe_div(numerator, denominator, name="value"): + """Computes a safe divide which returns 0 if the denominator is zero. + + Note that the function contains an additional conditional check that is + necessary for avoiding situations where the loss is zero causing NaNs to + creep into the gradient computation. + + Args: + numerator: An arbitrary `Tensor`. + denominator: A `Tensor` whose shape matches `numerator` and whose values are + assumed to be non-negative. + name: An optional name for the returned op. + + Returns: + The element-wise value of the numerator divided by the denominator. + """ + return math_ops.select( + math_ops.greater(denominator, 0), + math_ops.div(numerator, math_ops.select( + math_ops.equal(denominator, 0), + array_ops.ones_like(denominator), denominator)), + array_ops.zeros_like(numerator), + name=name) + + +def _safe_mean(losses, num_present): + """Computes a safe mean of the losses. + + Args: + losses: A tensor whose elements contain individual loss measurements. + num_present: The number of measurable losses in the tensor. + + Returns: + A scalar representing the mean of the losses. If `num_present` is zero, + then zero is returned. + """ + total_loss = math_ops.reduce_sum(losses) + return _safe_div(total_loss, num_present) + + +def _num_present(losses, weights, per_batch=False): + """Computes the number of elements in the loss function induced by `weights`. + + A given weights tensor induces different numbers of usable elements in the + `losses` tensor. The `weights` tensor is broadcast across `losses` for all + possible dimensions. For example, if `losses` is a tensor of dimension + [4, 5, 6, 3] and `weights` is a tensor of size [4, 5], then `weights` is, in + effect, tiled to match the size of `losses`. 
Following this effective tile, + the total number of present elements is the number of non-zero weights. + + Args: + losses: A tensor of size [batch_size, d1, ... dN]. + weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + per_batch: Whether to return the number of elements per batch or as a sum + total. + + Returns: + The number of present (non-zero) elements in the losses tensor. If + `per_batch` is True, the value is returned as a tensor of size + [batch_size]. Otherwise, a single scalar tensor is returned. + """ + # If weights is a scalar, its easy to compute: + if weights.get_shape().ndims == 0: + batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses), + [0], [1]), []) + num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)), + math_ops.to_float(batch_size)) + num_per_batch = math_ops.select(math_ops.equal(weights, 0), + 0.0, num_per_batch) + num_per_batch = math_ops.mul(array_ops.ones( + array_ops.reshape(batch_size, [1])), num_per_batch) + return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) + + # First, count the number of nonzero weights: + if weights.get_shape().ndims >= 1: + reduction_indices = list(range(1, weights.get_shape().ndims)) + num_nonzero_per_batch = math_ops.reduce_sum( + math_ops.to_float(math_ops.not_equal(weights, 0)), + reduction_indices=reduction_indices) + + # Next, determine the number of elements that weight would broadcast to: + broadcast_dims = array_ops.slice(array_ops.shape(losses), + [weights.get_shape().ndims], [-1]) + num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims)) + + num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast) + return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) + + +def compute_weighted_loss( + losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES): + """Computes the weighted loss. + + Args: + losses: A tensor of size [batch_size, d1, ... dN]. + weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N. + scope: the scope for the operations performed in computing the loss. + loss_collection: the loss will be added to these collections. + + Returns: + A scalar `Tensor` that returns the weighted loss. + + Raises: + ValueError: If `weights` is `None` or the shape is not compatible with + `losses`, or if the number of dimensions (rank) of either `losses` or + `weights` is missing. + """ + with ops.name_scope(scope, "weighted_loss", [losses, weights]): + losses = ops.convert_to_tensor(losses) + input_dtype = losses.dtype + losses = math_ops.to_float(losses) + weights = math_ops.to_float(ops.convert_to_tensor(weights)) + + if losses.get_shape().ndims is None: + raise ValueError("losses.get_shape().ndims cannot be None") + weights_shape = weights.get_shape() + if weights_shape.ndims is None: + raise ValueError("weight.get_shape().ndims cannot be None") + + if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1): + weights = array_ops.squeeze(weights, [-1]) + + total_loss = _scale_losses(losses, weights) + num_present = _num_present(losses, weights) + mean_loss = _safe_mean(total_loss, num_present) + # convert the result back to the input type + mean_loss = math_ops.cast(mean_loss, input_dtype) + util.add_loss(mean_loss, loss_collection) + return mean_loss + + +def absolute_difference( + labels, predictions, weights=1.0, scope=None, + loss_collection=ops.GraphKeys.LOSSES): + """Adds an Absolute Difference loss to the training procedure. 
+
+
+def absolute_difference(
+    labels, predictions, weights=1.0, scope=None,
+    loss_collection=ops.GraphKeys.LOSSES):
+  """Adds an Absolute Difference loss to the training procedure.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector. If the shape
+  of `weights` matches the shape of `predictions`, then the loss of each
+  measurable element of `predictions` is scaled by the corresponding value of
+  `weights`.
+
+  Args:
+    labels: The ground truth output tensor, same dimensions as `predictions`.
+    predictions: The predicted outputs.
+    weights: Coefficients for the loss: a scalar, a tensor of shape
+      [batch_size], or a tensor whose shape matches `predictions`.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+  """
+  with ops.name_scope(scope, "absolute_difference",
+                      [predictions, labels, weights]) as scope:
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+    predictions = math_ops.to_float(predictions)
+    labels = math_ops.to_float(labels)
+    losses = math_ops.abs(math_ops.sub(predictions, labels))
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
+
+
+def cosine_distance(
+    labels, predictions, dim=None, weights=1.0, scope=None,
+    loss_collection=ops.GraphKeys.LOSSES):
+  """Adds a cosine-distance loss to the training procedure.
+
+  Note that the function assumes that `predictions` and `labels` are already
+  unit-normalized.
+
+  Args:
+    labels: A `Tensor` whose shape matches `predictions`.
+    predictions: An arbitrary matrix.
+    dim: The dimension along which the cosine distance is computed.
+    weights: Coefficients for the loss: a scalar, a tensor of shape
+      [batch_size], or a tensor whose shape matches `predictions`.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If `predictions` shape doesn't match `labels` shape, or if
+      `dim` is `None`.
+  """
+  if dim is None:
+    raise ValueError("`dim` cannot be None.")
+  with ops.name_scope(scope, "cosine_distance_loss",
+                      [predictions, labels, weights]) as scope:
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+    predictions = math_ops.to_float(predictions)
+    labels = math_ops.to_float(labels)
+
+    radial_diffs = math_ops.mul(predictions, labels)
+    losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
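Since `cosine_distance` assumes unit-normalized inputs, the loss reduces to one minus the dot product along `dim`. A small NumPy sketch with made-up unit vectors:

  import numpy as np

  labels = np.array([[1.0, 0.0], [0.6, 0.8]])       # unit-norm rows
  predictions = np.array([[1.0, 0.0], [0.8, 0.6]])  # unit-norm rows

  losses = 1.0 - np.sum(predictions * labels, axis=1)
  print(losses)  # approximately [0. 0.04]: 0 for identical vectors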
+
+
+def hinge_loss(labels, logits, weights=1.0, scope=None,
+               loss_collection=ops.GraphKeys.LOSSES):
+  """Adds a hinge loss to the training procedure.
+
+  Args:
+    labels: The ground truth output tensor. Its shape should match the shape of
+      logits. The values of the tensor are expected to be 0.0 or 1.0.
+    logits: The logits, a float tensor.
+    weights: Coefficients for the loss: a scalar, a tensor of shape
+      [batch_size], or a tensor whose shape matches `logits`.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which the loss will be added.
+
+  Returns:
+    A scalar `Tensor` of the loss value.
+
+  Raises:
+    ValueError: If the shapes of `logits` and `labels` don't match.
+  """
+  with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope:
+    logits.get_shape().assert_is_compatible_with(labels.get_shape())
+    # We first need to convert binary labels to -1/1 labels (as floats).
+    labels = math_ops.to_float(labels)
+    all_ones = array_ops.ones_like(labels)
+    labels = math_ops.sub(2 * labels, all_ones)
+    losses = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
+
+
+def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
+             loss_collection=ops.GraphKeys.LOSSES):
+  """Adds a Log Loss term to the training procedure.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector. If the shape
+  of `weights` matches the shape of `predictions`, then the loss of each
+  measurable element of `predictions` is scaled by the corresponding value of
+  `weights`.
+
+  Args:
+    labels: The ground truth output tensor, same dimensions as `predictions`.
+    predictions: The predicted outputs.
+    weights: Coefficients for the loss: a scalar, a tensor of shape
+      [batch_size], or a tensor whose shape matches `predictions`.
+    epsilon: A small increment to add to avoid taking a log of zero.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which the loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+  """
+  with ops.name_scope(scope, "log_loss",
+                      [predictions, labels, weights]) as scope:
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+    predictions = math_ops.to_float(predictions)
+    labels = math_ops.to_float(labels)
+    losses = -math_ops.mul(
+        labels,
+        math_ops.log(predictions + epsilon)) - math_ops.mul(
+            (1 - labels), math_ops.log(1 - predictions + epsilon))
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
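The `epsilon` in `log_loss` only guards against taking log(0); away from the endpoints the expression is ordinary binary cross-entropy. A quick NumPy check with made-up values:

  import numpy as np

  labels = np.array([1.0, 0.0, 1.0])
  predictions = np.array([0.9, 0.2, 0.5])
  epsilon = 1e-7

  losses = -(labels * np.log(predictions + epsilon) +
             (1 - labels) * np.log(1 - predictions + epsilon))
  print(losses.round(3))  # [0.105 0.223 0.693], i.e. -log(0.9), -log(0.8), -log(0.5)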
+
+
+def mean_pairwise_squared_error(labels, predictions, weights=1.0, scope=None,
+                                loss_collection=ops.GraphKeys.LOSSES):
+  """Adds a pairwise-errors-squared loss to the training procedure.
+
+  Unlike `mean_squared_error`, which is a measure of the differences between
+  corresponding elements of `predictions` and `labels`,
+  `mean_pairwise_squared_error` is a measure of the differences between pairs
+  of corresponding elements of `predictions` and `labels`.
+
+  For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], there are
+  three pairs of differences, which are summed to compute the loss:
+    loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3
+
+  Note that since the inputs are of size [batch_size, d0, ... dN], the
+  corresponding pairs are computed within each batch sample but not across
+  samples within a batch. For example, if `predictions` represents a batch of
+  16 grayscale images of size [batch_size, 100, 200], then the set of pairs
+  is drawn from within each image, but not across images.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector.
+
+  Args:
+    labels: The ground truth output tensor, whose shape must match the shape of
+      the `predictions` tensor.
+    predictions: The predicted outputs, a tensor of size [batch_size, d0, ..
+      dN] where N+1 is the total number of dimensions in `predictions`.
+    weights: Coefficients for the loss: a scalar, a tensor of shape
+      [batch_size], or a tensor whose shape matches `predictions`.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which the loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+  """
+  with ops.name_scope(scope, "mean_pairwise_squared_error",
+                      [predictions, labels, weights]) as scope:
+    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+    predictions = math_ops.to_float(predictions)
+    labels = math_ops.to_float(labels)
+    weights = math_ops.to_float(ops.convert_to_tensor(weights))
+
+    diffs = math_ops.sub(predictions, labels)
+
+    # Need to verify here since the function doesn't use compute_weighted_loss
+    if diffs.get_shape().ndims is None:
+      raise ValueError("diffs.get_shape().ndims cannot be None")
+    if weights.get_shape().ndims is None:
+      raise ValueError("weights.get_shape().ndims cannot be None")
+
+    reduction_indices = list(range(1, diffs.get_shape().ndims))
+
+    sum_squares_diff_per_batch = math_ops.reduce_sum(
+        math_ops.square(diffs),
+        reduction_indices=reduction_indices)
+    num_present_per_batch = _num_present(diffs, weights, per_batch=True)
+
+    term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
+                            num_present_per_batch)
+
+    sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
+    term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
+                            math_ops.square(num_present_per_batch))
+
+    loss = _scale_losses(term1 - term2, weights)
+
+    mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
+                                loss,
+                                array_ops.zeros_like(loss),
+                                name="value")
+    util.add_loss(mean_loss, loss_collection)
+    return mean_loss
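The `term1 - term2` computation above avoids materializing all pairs of differences: for n present elements it equals the mean of (d_i - d_j)^2 over all n^2 ordered pairs, i.e. twice the variance of the per-element diffs. A NumPy check with arbitrary values:

  import numpy as np

  d = np.array([0.5, -1.0, 2.0])  # per-element diffs within one sample
  n = d.size

  term1 = 2.0 * np.sum(d ** 2) / n
  term2 = 2.0 * np.sum(d) ** 2 / n ** 2
  closed_form = term1 - term2                            # 3.5 - 0.5 = 3.0

  brute_force = np.mean((d[:, None] - d[None, :]) ** 2)  # all ordered pairs
  print(np.isclose(closed_form, brute_force))            # True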
+ """ + with ops.name_scope(scope, "mean_squared_error", + [predictions, labels, weights]) as scope: + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + predictions = math_ops.to_float(predictions) + labels = math_ops.to_float(labels) + losses = math_ops.square(math_ops.sub(predictions, labels)) + return compute_weighted_loss(losses, weights, scope, loss_collection) + + +def sigmoid_cross_entropy( + multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None, + loss_collection=ops.GraphKeys.LOSSES): + """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits. + + `weight` acts as a coefficient for the loss. If a scalar is provided, + then the loss is simply scaled by the given value. If `weight` is a + tensor of size [`batch_size`], then the loss weights apply to each + corresponding sample. + + If `label_smoothing` is nonzero, smooth the labels towards 1/2: + + new_multiclass_labels = multiclass_labels * (1 - label_smoothing) + + 0.5 * label_smoothing + + Args: + multi_class_labels: [batch_size, num_classes] target labels in (0, 1). + logits: [batch_size, num_classes] logits outputs of the network . + weights: Coefficients for the loss. The tensor must be a scalar, a tensor of + shape [batch_size] or shape [batch_size, num_classes]. + label_smoothing: If greater than 0 then smooth the labels. + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which the loss will be added. + + Returns: + A scalar `Tensor` representing the loss value. + + Raises: + ValueError: If the shape of `logits` doesn't match that of + `multi_class_labels` or if the shape of `weight` is invalid, or if + `weight` is None. + """ + with ops.name_scope(scope, "sigmoid_cross_entropy_loss", + [logits, multi_class_labels, weights]) as scope: + logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape()) + + multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype) + + if label_smoothing > 0: + multi_class_labels = (multi_class_labels * (1 - label_smoothing) + + 0.5 * label_smoothing) + + losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels, + name="xentropy") + return compute_weighted_loss(losses, weights, scope, loss_collection) + + +def softmax_cross_entropy( + onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None, + loss_collection=ops.GraphKeys.LOSSES): + """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits. + + `weight` acts as a coefficient for the loss. If a scalar is provided, + then the loss is simply scaled by the given value. If `weight` is a + tensor of size [`batch_size`], then the loss weights apply to each + corresponding sample. + + If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes: + new_onehot_labels = onehot_labels * (1 - label_smoothing) + + label_smoothing / num_classes + + Args: + onehot_labels: [batch_size, num_classes] target one_hot_encoded labels. + logits: [batch_size, num_classes] logits outputs of the network . + weights: Coefficients for the loss. The tensor must be a scalar or a tensor + of shape [batch_size]. + label_smoothing: If greater than 0 then smooth the labels. + scope: the scope for the operations performed in computing the loss. + loss_collection: collection to which the loss will be added. + + Returns: + A scalar `Tensor` representing the loss value. 
+
+
+def softmax_cross_entropy(
+    onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
+    loss_collection=ops.GraphKeys.LOSSES):
+  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weights` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
+
+      new_onehot_labels = onehot_labels * (1 - label_smoothing)
+                          + label_smoothing / num_classes
+
+  Args:
+    onehot_labels: [batch_size, num_classes] target one-hot-encoded labels.
+    logits: [batch_size, num_classes] logits outputs of the network.
+    weights: Coefficients for the loss. The tensor must be a scalar or a
+      tensor of shape [batch_size].
+    label_smoothing: If greater than 0 then smooth the labels.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which the loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
+      or if the shape of `weights` is invalid or if `weights` is None.
+  """
+  with ops.name_scope(scope, "softmax_cross_entropy_loss",
+                      [logits, onehot_labels, weights]) as scope:
+    logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
+
+    onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
+
+    if label_smoothing > 0:
+      num_classes = math_ops.cast(
+          array_ops.shape(onehot_labels)[1], logits.dtype)
+      smooth_positives = 1.0 - label_smoothing
+      smooth_negatives = label_smoothing / num_classes
+      onehot_labels = onehot_labels * smooth_positives + smooth_negatives
+
+    losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
+                                                  name="xentropy")
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
+
+
+def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
+                                 loss_collection=ops.GraphKeys.LOSSES):
+  """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weights` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+  Args:
+    labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
+      `int64` in the range `[0, num_classes)`.
+    logits: [batch_size, num_classes] logits outputs of the network.
+    weights: Coefficients for the loss. The tensor must be a scalar or a
+      tensor of shape [batch_size] or [batch_size, 1].
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which the loss will be added.
+
+  Returns:
+    A scalar `Tensor` representing the loss value.
+
+  Raises:
+    ValueError: If the shapes of `logits`, `labels`, and `weights` are
+      incompatible, or if `weights` is None.
+  """
+  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
+                      [logits, labels, weights]) as scope:
+    labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
+    weights = array_ops.squeeze(weights)
+
+    losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
+                                                         name="xentropy")
+    return compute_weighted_loss(losses, weights, scope, loss_collection)
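Before weighting, `sparse_softmax_cross_entropy` reduces to picking the negative log-softmax at each example's label index. A NumPy sketch of that per-example quantity (the values are made up):

  import numpy as np

  logits = np.array([[2.0, 0.5, -1.0]])
  labels = np.array([0])

  # Numerically stable log-softmax, then gather the labeled class.
  shifted = logits - logits.max(axis=1, keepdims=True)
  log_softmax = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
  losses = -log_softmax[np.arange(labels.size), labels]
  print(losses.round(4))  # approximately [0.2413]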
diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py
new file mode 100644
index 0000000000..aaf324891f
--- /dev/null
+++ b/tensorflow/python/ops/losses/util.py
@@ -0,0 +1,88 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for manipulating the loss collections.
+
+
+@@add_loss
+@@get_losses
+@@get_regularization_losses
+@@get_total_loss
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
+  """Adds an externally defined loss to the collection of losses.
+
+  Args:
+    loss: A loss `Tensor`.
+    loss_collection: Optional collection to add the loss to.
+  """
+  if loss_collection:
+    ops.add_to_collection(loss_collection, loss)
+
+
+def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
+  """Gets the list of losses from the loss_collection.
+
+  Args:
+    scope: An optional scope for filtering the losses to return.
+    loss_collection: Optional losses collection.
+
+  Returns:
+    A list of loss tensors.
+  """
+  return ops.get_collection(loss_collection, scope)
+
+
+def get_regularization_losses(scope=None):
+  """Gets the regularization losses.
+
+  Args:
+    scope: An optional scope for filtering the losses to return.
+
+  Returns:
+    A list of regularization loss tensors.
+  """
+  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
+
+
+def get_total_loss(add_regularization_losses=True, name="total_loss"):
+  """Returns a tensor whose value represents the total loss.
+
+  Note that the function sums the losses in the LOSSES collection and,
+  optionally, the regularization losses.
+
+  Args:
+    add_regularization_losses: A boolean indicating whether or not to include
+      the regularization losses in the sum.
+    name: The name of the returned tensor.
+
+  Returns:
+    A `Tensor` whose value represents the total loss.
+
+  Raises:
+    ValueError: if `losses` is not iterable.
+  """
+  losses = get_losses()
+  if add_regularization_losses:
+    losses += get_regularization_losses()
+  return math_ops.add_n(losses, name=name)
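Taken together, the two new files split the work: losses.py computes each loss and registers it in the LOSSES collection via `util.add_loss`, and util.py reads the collection back out. A minimal sketch of the intended workflow under the new `tf.losses` namespace (graph-mode TensorFlow of this era; the constants are illustrative):

  import tensorflow as tf

  predictions = tf.constant([[0.2, 0.8]])
  labels = tf.constant([[0.0, 1.0]])

  # Each call returns a scalar loss and adds it to GraphKeys.LOSSES.
  mse = tf.losses.mean_squared_error(labels, predictions)
  mae = tf.losses.absolute_difference(labels, predictions)

  # Sums everything in the collection, plus any REGULARIZATION_LOSSES
  # by default.
  total = tf.losses.get_total_loss()

  with tf.Session() as sess:
    print(sess.run([mse, mae, total]))  # [0.04, 0.2, 0.24]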