author     A. Unique TensorFlower <gardener@tensorflow.org>   2016-12-02 08:56:45 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>    2016-12-02 09:06:39 -0800
commit     896285a8dca7bddbf328b3728683acf619f26c13
tree       863dcd57b331cfba86f2d73a37f7e90b1d4dede5
parent     d8b037828ab66f25a7b526848cdc3fa9b3b9f198
Moves tf.contrib.losses into core, with changes.
Change: 140855283
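In practical terms, callers move from the now-deprecated tf.contrib.losses symbols to the new tf.losses module added by this commit, with labels passed before predictions and the singular `weight` keyword replaced by `weights`. A minimal migration sketch based on the signatures in this diff; the tensor values are illustrative and the commented-out contrib call only approximates the old signature:

# Migration sketch (illustrative values; graph-mode API of this TF version).
import tensorflow as tf

labels = tf.constant([[1.0, 9.0, 2.0], [-5.0, -2.0, 6.0]])
predictions = tf.constant([[4.0, 8.0, 12.0], [8.0, 1.0, 3.0]])

# Before: tf.contrib.losses.absolute_difference(predictions, labels, weight=2.0)
# After this commit (labels come first and the keyword is `weights`):
loss = tf.losses.absolute_difference(labels, predictions, weights=2.0)

with tf.Session() as sess:
  print(sess.run(loss))  # 11.0: the mean absolute difference (5.5) scaled by 2.0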
-rw-r--r--  tensorflow/contrib/losses/python/losses/loss_ops.py    16
-rw-r--r--  tensorflow/python/BUILD                                  1
-rw-r--r--  tensorflow/python/__init__.py                           10
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                     7
-rw-r--r--  tensorflow/python/kernel_tests/losses_test.py         1142
-rw-r--r--  tensorflow/python/ops/losses/BUILD                      39
-rw-r--r--  tensorflow/python/ops/losses/__init__.py                21
-rw-r--r--  tensorflow/python/ops/losses/losses.py                 588
-rw-r--r--  tensorflow/python/ops/losses/util.py                    88
9 files changed, 1907 insertions(+), 5 deletions(-)
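The first file below, loss_ops.py, changes in only one way: each public contrib symbol gains an @deprecated(date, instructions) decorator, so it keeps working while warning callers to switch to the tf.losses equivalent. For readers unfamiliar with the pattern, here is a rough stand-in; the real decorator lives in tensorflow.python.util.deprecation and additionally rewrites the docstring and throttles repeated warnings:

import functools
import logging

def deprecated(date, instructions):
  """Illustrative stand-in for tensorflow.python.util.deprecation.deprecated."""
  def decorator(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
      # Warn, then fall through to the original implementation unchanged.
      logging.warning(
          "%s is deprecated and will be removed after %s. "
          "Instructions for updating: %s", func.__name__, date, instructions)
      return func(*args, **kwargs)
    return wrapper
  return decorator

# Used exactly as in the hunks below, e.g.:
#   @deprecated("2016-12-30", "Use tf.losses.log_loss instead.")
#   def log_loss(...): ...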
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index c17b251d3e..780def4269 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -28,7 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
-
+from tensorflow.python.util.deprecation import deprecated
__all__ = ["absolute_difference",
"add_loss",
@@ -141,6 +141,7 @@ def _safe_mean(losses, num_present):
return _safe_div(total_loss, num_present)
+@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`.", "weight")
def compute_weighted_loss(
@@ -235,6 +236,7 @@ def _num_present(losses, weights, per_batch=False):
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
+@deprecated("2016-12-30", "Use tf.losses.add_loss instead.")
@add_arg_scope
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
"""Adds a externally defined loss to the collection of losses.
@@ -247,6 +249,7 @@ def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
ops.add_to_collection(loss_collection, loss)
+@deprecated("2016-12-30", "Use tf.losses.get_losses instead.")
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
"""Gets the list of losses from the loss_collection.
@@ -260,6 +263,7 @@ def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
return ops.get_collection(loss_collection, scope)
+@deprecated("2016-12-30", "Use tf.losses.get_regularization_losses instead.")
def get_regularization_losses(scope=None):
"""Gets the regularization losses.
@@ -272,6 +276,7 @@ def get_regularization_losses(scope=None):
return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
+@deprecated("2016-12-30", "Use tf.losses.get_total_loss instead.")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
"""Returns a tensor whose value represents the total loss.
@@ -294,6 +299,7 @@ def get_total_loss(add_regularization_losses=True, name="total_loss"):
return math_ops.add_n(losses, name=name)
+@deprecated("2016-12-30", "Use tf.losses.absolute_difference instead.")
@deprecated_args(
"2016-11-25",
"`targets` is being deprecated, use `labels`."
@@ -339,6 +345,7 @@ def absolute_difference(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.sigmoid_cross_entropy instead.")
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
def sigmoid_cross_entropy(
@@ -389,6 +396,7 @@ def sigmoid_cross_entropy(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.softmax_cross_entropy instead.")
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
def softmax_cross_entropy(
@@ -440,6 +448,7 @@ def softmax_cross_entropy(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.sparse_softmax_cross_entropy instead.")
@deprecated_args(
"2016-11-25", "`weight` is being deprecated, use `weights`", "weight")
def sparse_softmax_cross_entropy(
@@ -479,6 +488,7 @@ def sparse_softmax_cross_entropy(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.log_loss instead.")
@deprecated_args(
"2016-11-25",
"`targets` is being deprecated, use `labels`."
@@ -528,6 +538,7 @@ def log_loss(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.hinge_loss instead.")
@deprecated_args(
"2016-11-25", "`target` is being deprecated, use `labels`.", "target")
def hinge_loss(logits, labels=None, scope=None, target=None):
@@ -557,6 +568,7 @@ def hinge_loss(logits, labels=None, scope=None, target=None):
return nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+@deprecated("2016-12-30", "Use tf.losses.mean_squared_error instead.")
@deprecated_args(
"2016-11-25",
"`targets` is being deprecated, use `labels`."
@@ -602,6 +614,7 @@ def mean_squared_error(
return compute_weighted_loss(losses, weights, scope=scope)
+@deprecated("2016-12-30", "Use tf.losses.mean_pairwise_squared_error instead.")
@deprecated_args(
"2016-11-25",
"`targets` is being deprecated, use `labels`."
@@ -691,6 +704,7 @@ def mean_pairwise_squared_error(
return mean_loss
+@deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.")
@deprecated_args(
"2016-11-25",
"`targets` is being deprecated, use `labels`."
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 180aa291ec..c5c2f77378 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -43,6 +43,7 @@ py_library(
":training",
":ops",
":test_ops",
+ "//tensorflow/python/ops/losses",
"//tensorflow/python/debug:debug_py",
] + if_not_windows([
"//tensorflow/contrib:contrib_py",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index cd2f4fa328..e323c9b6a4 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -87,6 +87,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import resources
from tensorflow.python.ops import sdca_ops as sdca
from tensorflow.python.ops import image_ops as image
+from tensorflow.python.ops import losses
from tensorflow.python.ops import sets
from tensorflow.python.user_ops import user_ops
from tensorflow.python.util import compat
@@ -218,6 +219,7 @@ _allowed_symbols.extend([
'graph_util',
'image',
'logging',
+ 'losses',
'newaxis',
'nn',
'python_io',
@@ -245,10 +247,10 @@ _allowed_symbols.extend([
remove_undocumented(__name__, _allowed_symbols,
[framework_lib, array_ops, client_lib, check_ops,
compat, constant_op, control_flow_ops, functional_ops,
- histogram_ops, io_ops, math_ops, nn, resource_loader,
- resources, sets, script_ops, session_ops, sparse_ops,
- state_ops, string_ops, summary, tensor_array_ops, train,
- layers])
+ histogram_ops, io_ops, losses, math_ops, nn,
+ resource_loader, resources, sets, script_ops, session_ops,
+ sparse_ops, state_ops, string_ops, summary,
+ tensor_array_ops, train, layers])
# Special dunders that we choose to export:
_exported_dunders = set([
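The tensorflow/python/__init__.py change above has two halves: importing tensorflow.python.ops.losses so the module object exists, and adding 'losses' both to _allowed_symbols and to the remove_undocumented(...) call so the new name survives namespace pruning. A simplified sketch of what that pruning amounts to; the real remove_undocumented in TensorFlow's util code also honours @@-style docstring annotations, so this conveys only the allowlist idea:

import sys

def remove_undocumented(module_name, allowed_symbols, doc_modules):
  # Simplified: delete any public attribute of the module that is neither
  # explicitly allowed nor named after one of the documented modules.
  module = sys.modules[module_name]
  keep = set(allowed_symbols)
  keep.update(m.__name__.rsplit('.', 1)[-1] for m in doc_modules)
  for name in list(vars(module)):
    if not name.startswith('_') and name not in keep:
      delattr(module, name)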
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index fccb225a2b..3185b1fd06 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -194,6 +194,13 @@ tf_py_test(
)
tf_py_test(
+ name = "losses_test",
+ size = "small",
+ srcs = ["losses_test.py"],
+ additional_deps = ["//tensorflow:tensorflow_py"],
+)
+
+tf_py_test(
name = "matrix_inverse_op_test",
size = "small",
srcs = ["matrix_inverse_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py
new file mode 100644
index 0000000000..2393124ba3
--- /dev/null
+++ b/tensorflow/python/kernel_tests/losses_test.py
@@ -0,0 +1,1142 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for losses."""
+# pylint: disable=unused-import,g-bad-import-order
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+# pylint: enable=unused-import
+
+import numpy as np
+import tensorflow as tf
+
+
+class AbsoluteDifferenceLossTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
+ self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+
+ def testValueErrorThrownWhenWeightIsNone(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.absolute_difference(
+ self._predictions, self._predictions, weights=None)
+
+ def testAllCorrectNoLossWeight(self):
+ loss = tf.losses.absolute_difference(
+ self._predictions, self._predictions)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testNonZeroLoss(self):
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions)
+ with self.test_session():
+ self.assertAlmostEqual(5.5, loss.eval(), 3)
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ weights = 2.3
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ weights = 2.3
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, tf.constant(weights))
+ with self.test_session():
+ self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ weights = tf.constant([1.2, 0.0], shape=[2,])
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(5.6, loss.eval(), 3)
+
+ def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
+ weights = tf.constant([1.2, 0.0], shape=[2, 1])
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(5.6, loss.eval(), 3)
+
+ def testNonZeroLossWithSampleSpecificWeights(self):
+ weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(16.6, loss.eval(), 3)
+
+ def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
+ weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(6.0, loss.eval(), 3)
+
+ def testLossWithSampleSpecificWeightsAllZero(self):
+ weights = tf.zeros((2, 3))
+ loss = tf.losses.absolute_difference(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+
+class SoftmaxCrossEntropyLossTest(tf.test.TestCase):
+
+ def testNoneWeightRaisesValueError(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.softmax_cross_entropy(labels, logits, weights=None)
+
+ def testAllCorrect(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ loss = tf.losses.softmax_cross_entropy(labels, logits)
+ self.assertEquals('softmax_cross_entropy_loss/value', loss.op.name)
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testAllWrong(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = 2.3
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(labels, logits, weights)
+ self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = 2.3
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(
+ labels, logits, tf.constant(weights))
+ self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = tf.constant([1.2, 3.4, 5.6], shape=[3])
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(labels, logits, weights)
+ self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
+
+ def testAllWrongAllWeightsMissing(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = tf.constant([0, 0, 0], shape=[3])
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(labels, logits, weights)
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testSomeWeightsMissing(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = tf.constant([1.2, 0, 0], shape=[3])
+ with self.test_session():
+ loss = tf.losses.softmax_cross_entropy(labels, logits, weights)
+ self.assertAlmostEqual(12.0, loss.eval(), 3)
+
+ def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ weights = tf.constant([[3, 4, 5],
+ [2, 6, 0],
+ [8, 0, 1]])
+
+ with self.assertRaises(ValueError):
+ tf.losses.softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+ def testSoftmaxLabelSmoothing(self):
+ with self.test_session():
+ # Softmax Cross Entropy Loss is:
+ # -\sum_i p_i \log q_i
+ # where for a softmax activation
+ # \log q_i = x_i - \log \sum_j \exp x_j
+ # = x_i - x_max - \log \sum_j \exp (x_j - x_max)
+ # For our activations, [100, -100, -100], the log partition function becomes
+ # \log ( exp(0) + exp(-200) + exp(-200) ) = 0
+ # so our log softmaxes become: [0, -200, -200]
+ # so our cross entropy loss is:
+ # -(1 - L + L/n) * 0 + 400 * L/n = 400 L/n
+ logits = tf.constant([[100.0, -100.0, -100.0]])
+ labels = tf.constant([[1, 0, 0]])
+ label_smoothing = 0.1
+ loss = tf.losses.softmax_cross_entropy(
+ labels, logits, label_smoothing=label_smoothing)
+ self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
+ expected_value = 400.0 * label_smoothing / 3.0
+ self.assertAlmostEqual(loss.eval(), expected_value, 3)
+
+
+class SparseSoftmaxCrossEntropyLossTest(tf.test.TestCase):
+
+ def testNoneWeightRaisesValueError(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0], [1], [2]])
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=None)
+
+ def testAllCorrectInt32Labels(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0], [1], [2]], dtype=tf.int32)
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testAllCorrectInt64Labels(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[0], [1], [2]], dtype=tf.int64)
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testAllCorrectNonColumnLabels(self):
+ with self.test_session():
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([0, 1, 2])
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testAllWrongInt32Labels(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]], dtype=tf.int32)
+
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testAllWrongInt64Labels(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]], dtype=tf.int64)
+
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testAllWrongNonColumnLabels(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([2, 0, 1])
+
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 10.0, 3)
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = 2.3
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights)
+ self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = 2.3
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, tf.constant(weights))
+ self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = tf.constant([1.2, 3.4, 5.6], shape=[3])
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights)
+ self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
+
+ def testNonZeroLossWithColumnWeights(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = tf.constant([[1.2], [3.4], [5.6]])
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights)
+ self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
+
+ def testAllWrongAllWeightsMissing(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = tf.constant([0, 0, 0], shape=[3])
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights)
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testSomeWeightsMissing(self):
+ logits = tf.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = tf.constant([[2], [0], [1]])
+ weights = tf.constant([1.2, 0, 0], shape=[3])
+ with self.test_session():
+ loss = tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights)
+ self.assertAlmostEqual(12.0, loss.eval(), 3)
+
+ def testMeasurementSpecificWeightsRaisesException(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[0], [1], [2]])
+ weights = tf.constant([[3, 4, 5],
+ [2, 6, 0],
+ [8, 0, 1]])
+
+ with self.assertRaises(ValueError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+ def testInconsistentWeightSizeRaisesException(self):
+ """The weight tensor has incorrect number of elements."""
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[0], [1], [2]])
+ weights = tf.constant([1.2, 3.4, 5.6, 7.8])
+
+ with self.assertRaises(ValueError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+ def testInconsistentLabelSizeRaisesException(self):
+ """The label tensor has incorrect number of elements."""
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[0], [1], [2], [3]])
+ weights = tf.constant([1.2, 3.4, 5.6])
+
+ with self.assertRaises(ValueError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+ def testInconsistentWeightShapeRaisesException(self):
+ """The weight tensor has incorrect shape."""
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0, -100.0],
+ [-100.0, -100.0, 100.0, -100.0],
+ [-100.0, -100.0, -100.0, 100.0]])
+ labels = tf.constant([[0], [1], [2], [3]])
+ weights = tf.constant([[1.2, 3.4], [5.6, 7.8]])
+
+ with self.assertRaises(ValueError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+ def testInconsistentLabelShapeRaisesException(self):
+ """The label tensor has incorrect shape."""
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0, -100.0],
+ [-100.0, -100.0, 100.0, -100.0],
+ [-100.0, -100.0, -100.0, 100.0]])
+ labels = tf.constant([[0, 1], [2, 3]])
+ weights = tf.constant([1.2, 3.4, 5.6, 7.8])
+
+ with self.assertRaises(tf.errors.InvalidArgumentError):
+ tf.losses.sparse_softmax_cross_entropy(
+ labels, logits, weights=weights).eval()
+
+
+class SigmoidCrossEntropyLossTest(tf.test.TestCase):
+
+ def testAllCorrectSigmoid(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]])
+ loss = tf.losses.sigmoid_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testLossWithSingleDimPlaceholderForLogitsAndWeights1(self):
+ logits = tf.placeholder(tf.float32, shape=(None, 1))
+ labels = tf.placeholder(tf.float32, shape=(None, 1))
+ weights = tf.ones_like(logits, dtype=tf.float32)
+
+ loss = tf.losses.sigmoid_cross_entropy(labels, logits, weights)
+
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={
+ logits: np.ones((32, 1)),
+ labels: np.ones((32, 1)),
+ })
+ self.assertAlmostEqual(0.313, loss, 3)
+
+ def testLossWithSingleDimPlaceholderForLogitsAndWeights2(self):
+ logits = tf.placeholder(tf.float32, shape=(None, 2))
+ labels = tf.placeholder(tf.float32, shape=(None, 2))
+ weights = tf.ones_like(logits, dtype=tf.float32)
+
+ loss = tf.losses.sigmoid_cross_entropy(labels, logits, weights)
+
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={
+ logits: np.ones((32, 2)),
+ labels: np.ones((32, 2)),
+ })
+ self.assertAlmostEqual(0.313, loss, 3)
+
+ def testAllWrongSigmoid(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ loss = tf.losses.sigmoid_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+ self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
+
+ def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0],
+ [-100.0, 100.0, -100.0],
+ [-100.0, -100.0, 100.0]])
+ labels = tf.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = tf.constant([[3, 4, 5],
+ [2, 6, 0],
+ [8, 0, 1]])
+ loss = tf.losses.sigmoid_cross_entropy(
+ labels, logits, weights)
+ self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+ self.assertAlmostEqual(1700.0 / 7.0, loss.eval(), 3)
+
+ def testMultiCorrectSigmoid(self):
+ logits = tf.constant([[100.0, -100.0, 100.0],
+ [100.0, 100.0, -100.0],
+ [-100.0, 100.0, 100.0]])
+ labels = tf.constant([[1, 0, 1],
+ [1, 1, 0],
+ [0, 1, 1]])
+ loss = tf.losses.sigmoid_cross_entropy(labels, logits)
+ self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+
+ with self.test_session():
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testSigmoidLabelSmoothingCorrect(self):
+ with self.test_session():
+ logits = tf.constant([[100.0, -100.0, -100.0]])
+ labels = tf.constant([[1, 0, 1]])
+ # Sigmoid cross entropy loss is:
+ # max(x,0) - x*z + log(1 + exp(-abs(x)))
+ # The new labels are:
+ # z' = z * (1 - L) + 0.5 L
+ # 1 -> 1 - 0.5 L
+ # 0 -> 0.5 L
+ # here we expect:
+ # 1/3 * (100 - 100 * (1 - 0.5 L) + 0
+ # + 0 + 100 * (0.5 L) + 0
+ # + 0 + 100 * (1 - 0.5 L) + 0)
+ # = 1/3 * (100 + 50 L)
+ label_smoothing = 0.1
+ loss = tf.losses.sigmoid_cross_entropy(
+ labels, logits, label_smoothing=label_smoothing)
+ self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
+ expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
+ self.assertAlmostEqual(loss.eval(), expected_value, 3)
+
+ def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
+ with self.test_session():
+ label_smoothing = 0.1
+ sigmoid_logits = tf.constant([[100.0, -100.0, -100.0]])
+ sigmoid_labels = tf.constant([[1, 0, 1]])
+ sigmoid_loss = tf.losses.sigmoid_cross_entropy(
+ sigmoid_labels, sigmoid_logits, label_smoothing=label_smoothing)
+
+ softmax_logits = tf.constant([[0.0, 100.0], [100.0, 0.0], [100.0, 0.0]])
+ softmax_labels = tf.constant([[0, 1], [1, 0], [0, 1]])
+ softmax_loss = tf.losses.softmax_cross_entropy(
+ softmax_labels, softmax_logits, label_smoothing=label_smoothing)
+ self.assertAlmostEqual(sigmoid_loss.eval(), softmax_loss.eval(), 3)
+
+
+class LogLossTest(tf.test.TestCase):
+
+ def setUp(self):
+ predictions = np.asarray([.9, .2, .2, .8, .4, .6]).reshape((2, 3))
+ labels = np.asarray([1.0, 0.0, 1.0, 1.0, 0.0, 0.0]).reshape((2, 3))
+
+ self._np_predictions = predictions
+ self._np_labels = labels
+
+ epsilon = 1e-7
+ self._expected_losses = np.multiply(
+ labels, np.log(predictions + epsilon)) + np.multiply(
+ 1 - labels, np.log(1 - predictions + epsilon))
+
+ self._predictions = tf.constant(predictions)
+ self._labels = tf.constant(labels)
+
+ def testValueErrorThrownWhenWeightIsNone(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.log_loss(self._labels, self._labels, weights=None)
+
+ def testAllCorrectNoLossWeight(self):
+ loss = tf.losses.log_loss(self._labels, self._labels)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testAllCorrectNoLossWeightWithPlaceholder(self):
+ tf_predictions = tf.placeholder(tf.float32, shape=self._np_labels.shape)
+ loss = tf.losses.log_loss(self._labels, tf_predictions)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(feed_dict={
+ tf_predictions: self._np_labels}), 3)
+
+ def testNonZeroLoss(self):
+ loss = tf.losses.log_loss(self._labels, self._predictions)
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
+ loss.eval(), 3)
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ weights = 2.3
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
+ loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ weights = 2.3
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, tf.constant(weights))
+ with self.test_session():
+ self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
+ loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeightAndPlaceholder(self):
+ tf_predictions = tf.placeholder(tf.float32,
+ shape=self._np_predictions.shape)
+ weights = 2.3
+ loss = tf.losses.log_loss(
+ self._labels, tf_predictions, tf.constant(weights))
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
+ self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
+ loss, 3)
+
+ def testNonZeroLossWithScalarTensorWeightAndPlaceholderWithRankOnly(self):
+ tf_predictions = tf.placeholder(tf.float32, shape=[None, None])
+ weights = 2.3
+ loss = tf.losses.log_loss(
+ self._labels, tf_predictions, tf.constant(weights))
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
+ self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
+ loss, 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ weights = tf.constant([1.2, 3.4], shape=[2])
+ expected_losses = np.multiply(
+ self._expected_losses,
+ np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(expected_losses) / 6.0,
+ loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
+ weights = tf.constant([1.2, 0], shape=[2])
+ expected_losses = np.multiply(
+ self._expected_losses,
+ np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
+ loss.eval(), 3)
+
+ def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
+ weights = tf.constant([1.2, 0], shape=[2, 1])
+ expected_losses = np.multiply(
+ self._expected_losses,
+ np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape((2, 3)))
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(expected_losses) / 3.0,
+ loss.eval(), 3)
+
+ def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
+ weights = tf.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.log_loss(self._labels, self._predictions, weights)
+
+ def testNonZeroLossWithMeasurementSpecificWeights(self):
+ weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+ expected_losses = np.multiply(self._expected_losses, weights)
+
+ loss = tf.losses.log_loss(
+ self._labels,
+ self._predictions,
+ tf.constant(weights, shape=(2, 3)))
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
+
+ def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
+ weights = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))
+ expected_losses = np.multiply(self._expected_losses, weights)
+
+ tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
+ loss = tf.losses.log_loss(
+ self._labels,
+ tf_predictions,
+ tf.constant(weights, shape=(2, 3)))
+
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
+ self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
+
+ def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
+ weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
+ expected_losses = np.multiply(self._expected_losses, weights)
+
+ loss = tf.losses.log_loss(
+ self._labels,
+ self._predictions,
+ tf.constant(weights, shape=(2, 3)))
+ with self.test_session():
+ self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
+
+ def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
+ weights = np.array([0, 0, 0, 0, 0, 2]).reshape((2, 3))
+ expected_losses = np.multiply(self._expected_losses, weights)
+
+ tf_predictions = tf.placeholder(tf.float32, shape=[2, 3])
+ tf_weights = tf.constant(weights, shape=(2, 3))
+ loss = tf.losses.log_loss(self._labels, tf_predictions, tf_weights)
+
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
+ self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
+
+ def testLossWithSampleSpecificWeightsAllZero(self):
+ tf_weights = tf.zeros(shape=(2, 3))
+ loss = tf.losses.log_loss(
+ self._labels, self._predictions, tf_weights)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+
+class HingeLossTest(tf.test.TestCase):
+
+ def testIncompatibleShapes(self):
+ with self.test_session():
+ logits = tf.constant([[-1.0], [2.1]])
+ labels = tf.constant([0.0, 1.0])
+ with self.assertRaises(ValueError):
+ _ = tf.losses.hinge_loss(labels, logits).eval()
+
+ def testAllOutsideMargin(self):
+ with self.test_session():
+ logits = tf.constant([1.2, -1.4, -1.0, 2.1])
+ labels = tf.constant([1.0, 0.0, 0.0, 1.0])
+ loss = tf.losses.hinge_loss(labels, logits)
+ self.assertAllClose(loss.eval(), 0.0, atol=1e-3)
+
+ def testSomeInsideMargin(self):
+ with self.test_session():
+ logits = tf.constant([[-0.7], [-1.4], [1.4], [0.6]])
+ labels = tf.constant([[0.0], [0.0], [1.0], [1.0]])
+ loss = tf.losses.hinge_loss(labels, logits)
+ # Examples 1 and 4 are on the correct side of the hyperplane but within
+ # the margin so they incur some (small) loss.
+ self.assertAllClose(loss.eval(), 0.175, atol=1e-3)
+
+ def testSomeMisclassified(self):
+ with self.test_session():
+ logits = tf.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
+ labels = tf.constant([[[1.0], [0.0], [0.0], [1.0]]])
+ loss = tf.losses.hinge_loss(labels, logits)
+ # Examples 2 and 4 are on the wrong side of the hyperplane so they incur
+ # some (fairly large) loss.
+ self.assertAllClose(loss.eval(), 0.875, atol=1e-3)
+
+
+class MeanSquaredErrorTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._predictions = tf.constant([4, 8, 12, 8, 1, 3], shape=(2, 3))
+ self._labels = tf.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
+
+ def testValueErrorThrownWhenWeightIsNone(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.mean_squared_error(
+ self._predictions, self._predictions, weights=None)
+
+ def testAllCorrectNoLossWeight(self):
+ loss = tf.losses.mean_squared_error(
+ self._predictions, self._predictions)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testNonZeroLoss(self):
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions)
+ with self.test_session():
+ self.assertAlmostEqual(49.5, loss.eval(), 3)
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ weights = 2.3
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ weights = 2.3
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, tf.constant(weights))
+ with self.test_session():
+ self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ weights = tf.constant([1.2, 3.4], shape=[2,])
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
+
+ def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
+ weights = tf.constant([1.2, 3.4], shape=[2, 1])
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
+
+ def testNonZeroLossWithSampleSpecificWeights(self):
+ weights = tf.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
+
+ def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
+ weights = tf.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(18.0, loss.eval(), 3)
+
+ def testLossWithSampleSpecificWeightsAllZero(self):
+ weights = tf.zeros((2, 3))
+ loss = tf.losses.mean_squared_error(
+ self._labels, self._predictions, weights)
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+
+class MeanPairwiseSquaresErrorTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._predictions = np.array([[4, 8, 12],
+ [8, 1, 3]])
+ self._labels = np.array([[1, 9, 2],
+ [-5, -5, 7]])
+
+ batch_size, dims = self._labels.shape
+
+ # Compute the expected loss 'manually'.
+ total = np.zeros((batch_size, 1))
+ for b in range(batch_size):
+ for i in range(dims):
+ for j in range(dims):
+ x = self._predictions[b, i].item() - self._predictions[b, j].item()
+ y = self._labels[b, i].item() - self._labels[b, j].item()
+ tmp = (x-y) * (x-y)
+ total[b] += tmp
+
+ self._expected_losses = np.divide(total, 9.0)
+
+ def testValueErrorThrownWhenWeightIsNone(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._labels),
+ labels=tf.constant(self._labels),
+ weights=None)
+
+ def testAllCorrectNoLossWeight(self):
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._labels),
+ labels=tf.constant(self._labels))
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+ def testNonZeroLoss(self):
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels))
+ with self.test_session():
+ self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
+
+ def testGradientWithZeroWeight(self):
+ with tf.Graph().as_default():
+ tf.set_random_seed(0)
+
+ inputs = tf.ones((2, 3))
+ weights = tf.get_variable('weights',
+ shape=[3, 4],
+ initializer=tf.truncated_normal_initializer())
+ predictions = tf.matmul(inputs, weights)
+
+ optimizer = tf.train.MomentumOptimizer(learning_rate=0.001, momentum=0.9)
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions,
+ predictions,
+ 0)
+
+ gradients_to_variables = optimizer.compute_gradients(loss)
+
+ init_op = tf.global_variables_initializer()
+
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for grad, _ in gradients_to_variables:
+ np_grad = sess.run(grad)
+ self.assertFalse(np.isnan(np_grad).any())
+
+ def testNonZeroLossWithPythonScalarWeight(self):
+ weights = 2.3
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=weights)
+ with self.test_session():
+ self.assertAlmostEqual(weights * np.sum(self._expected_losses),
+ loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeight(self):
+ weights = 2.3
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=tf.constant(weights))
+ with self.test_session():
+ self.assertAlmostEqual(weights * np.sum(self._expected_losses),
+ loss.eval(), 3)
+
+ def testNonZeroLossWithScalarZeroWeight(self):
+ weights = 0
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=tf.constant(weights))
+ with self.test_session():
+ self.assertAlmostEqual(0, loss.eval(), 3)
+
+ def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
+ weights = 2.3
+ tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
+ tf_labels = tf.placeholder(tf.float32, shape=self._labels.shape)
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf_predictions,
+ labels=tf_labels,
+ weights=tf.constant(weights))
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={
+ tf_predictions: self._predictions,
+ tf_labels: self._labels,
+ })
+ self.assertAlmostEqual(weights * np.sum(self._expected_losses), loss, 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeights(self):
+ weights = np.asarray([2.0, 1.0]).reshape((2, 1))
+ expected_losses = np.multiply(weights, self._expected_losses)
+
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=tf.constant(weights, shape=[2]))
+ with self.test_session():
+ self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
+
+ def testZeroLossWithOneDimBatchZeroWeights(self):
+ weights = np.asarray([0.0, 0.0]).reshape((2, 1))
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=tf.constant(weights, shape=[2]))
+ with self.test_session():
+ self.assertAlmostEqual(0, loss.eval(), 3)
+
+ def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
+ weights = np.asarray([1.2, 3.4]).reshape((2, 1))
+ expected_losses = np.multiply(weights, self._expected_losses)
+
+ tf_predictions = tf.placeholder(tf.float32, shape=self._predictions.shape)
+ tf_labels = tf.placeholder(tf.int32, shape=self._labels.shape)
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf_predictions,
+ labels=tf_labels,
+ weights=tf.constant(weights, shape=[2]))
+
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={
+ tf_predictions: self._predictions,
+ tf_labels: self._labels,
+ })
+ self.assertAlmostEqual(np.sum(expected_losses), loss, 3)
+
+ def testLossWithAllZeroBatchSpecificWeights(self):
+ weights = np.zeros((2, 1))
+ loss = tf.losses.mean_pairwise_squared_error(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ weights=tf.constant(weights, shape=[2]))
+ with self.test_session():
+ self.assertAlmostEqual(0.0, loss.eval(), 3)
+
+
+class CosineDistanceLossTest(tf.test.TestCase):
+
+ def setUp(self):
+ self._predictions = np.asarray([[1, 0, 0], # Batch 1
+ [0, 0, -1],
+ [1, 0, 0], # Batch 2
+ [1, 0, 0],
+ [0, 0, -1], # Batch 3
+ [1, 0, 0]]).reshape((3, 2, 3))
+
+ self._labels = np.asarray([[1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0],
+ [1, 0, 0],
+ [0, 0, 1],
+ [0, 1, 0]]).reshape((3, 2, 3))
+
+ def testValueErrorThrownWhenWeightIsNone(self):
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.cosine_distance(
+ predictions=tf.constant(self._labels),
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=None)
+
+ def testAllCorrectNoWeights(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._labels),
+ labels=tf.constant(self._labels),
+ dim=2)
+ with self.test_session():
+ self.assertAlmostEqual(0, loss.eval(), 5)
+
+ def testPartiallyCorrectWithIntegerValues(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ dim=2)
+ with self.test_session():
+ self.assertAlmostEqual(1, loss.eval(), 5)
+
+ def testPartiallyCorrectFloatingPointValues(self):
+ predictions = np.matrix((
+ '0.819031913261206 0.567041924552012 0.087465312324590;'
+ '-0.665139432070255 -0.739487441769973 -0.103671883216994;'
+ '0.707106781186548 -0.707106781186548 0'))
+ labels = np.matrix((
+ '0.819031913261206 0.567041924552012 0.087465312324590;'
+ '0.665139432070255 0.739487441769973 0.103671883216994;'
+ '0.707106781186548 0.707106781186548 0'))
+
+ tf_preds = tf.constant(predictions, shape=(3, 1, 3), dtype=tf.float32)
+ tf_labels = tf.constant(labels, shape=(3, 1, 3), dtype=tf.float32)
+ loss = tf.losses.cosine_distance(tf_labels, tf_preds, dim=2)
+
+ with self.test_session():
+ self.assertAlmostEqual(1.0, loss.eval(), 5)
+
+ def testSampleSpecificWeights(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.constant([1, 0, 0]))
+ with self.test_session():
+ self.assertEqual(1.0, loss.eval())
+
+ def testMeasurementSpecificWeights(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
+ with self.test_session():
+ self.assertEqual(3.0 / 4.0, loss.eval())
+
+ def testValueErrorThrownWithShapelessPlaceholder(self):
+ tf_predictions = tf.placeholder(tf.float32)
+ with self.test_session():
+ with self.assertRaises(ValueError):
+ tf.losses.cosine_distance(
+ predictions=tf_predictions,
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
+
+ def testMeasurementSpecificWeightsWithPlaceholderWithShape(self):
+ tf_predictions = tf.placeholder(tf.float32, shape=self._labels.shape)
+ loss = tf.losses.cosine_distance(
+ predictions=tf_predictions,
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.constant([1, 0, 0, 1, 1, 1], shape=(3, 2)))
+ with self.test_session() as sess:
+ loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
+ self.assertEqual(3.0 / 4.0, loss)
+
+ def testZeroLossWhenAllSampleSpecificWeightsAreZero(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.zeros((3,)))
+ with self.test_session():
+ self.assertEqual(0, loss.eval())
+
+ def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
+ loss = tf.losses.cosine_distance(
+ predictions=tf.constant(self._predictions),
+ labels=tf.constant(self._labels),
+ dim=2,
+ weights=tf.zeros((3, 2)))
+ with self.test_session():
+ self.assertEqual(0, loss.eval())
+
+
+class AddLossTest(tf.test.TestCase):
+
+ def testNoCollectLossesBatch2(self):
+ logits = tf.constant([[1.2, 0.4, -1.0, -1.1]] * 2)
+ labels = tf.constant([[1.0, 0.0, 0.0, 1.0]] * 2)
+ self.assertFalse(tf.losses.get_losses())
+ tf.losses.absolute_difference(logits, labels, loss_collection=None)
+ tf.losses.log_loss(logits, labels, loss_collection=None)
+ tf.losses.mean_squared_error(logits, labels, loss_collection=None)
+ tf.losses.sigmoid_cross_entropy(logits, labels, loss_collection=None)
+ tf.losses.softmax_cross_entropy(logits, labels, loss_collection=None)
+ self.assertFalse(tf.losses.get_losses())
+
+if __name__ == '__main__':
+ tf.test.main()
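The weighted-loss expectations above (5.5, 5.6, 16.6, 12.0 and so on) all follow from the semantics implemented in losses.py later in this diff: weighted losses are summed and then divided by the number of elements whose broadcast weight is non-zero, not by the total element count. A small NumPy check of the AbsoluteDifferenceLossTest fixture under that reading; NumPy is used here only to mirror the arithmetic and is not part of this change:

import numpy as np

# |predictions - labels| for the fixture in AbsoluteDifferenceLossTest above.
losses = np.array([[3., 1., 10.], [13., 3., 3.]])
weights = np.array([1.2, 0.0])  # one weight per batch row

weighted_sum = (losses.sum(axis=1) * weights).sum()         # 14*1.2 + 19*0.0 = 16.8
num_present = np.count_nonzero(                             # elements with non-zero weight
    np.broadcast_to(weights[:, np.newaxis], losses.shape))  # = 3
print(weighted_sum / num_present)                           # 5.6, as the test expects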
diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD
new file mode 100644
index 0000000000..4a2edd99b2
--- /dev/null
+++ b/tensorflow/python/ops/losses/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//tensorflow:internal"],
+ features = [
+ "-layering_check",
+ "-parse_headers",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "losses",
+ srcs = [
+ "__init__.py",
+ "losses.py",
+ "util.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:nn_ops",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/python/ops/losses/__init__.py b/tensorflow/python/ops/losses/__init__.py
new file mode 100644
index 0000000000..3b0d0d8e5a
--- /dev/null
+++ b/tensorflow/python/ops/losses/__init__.py
@@ -0,0 +1,21 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Loss functions and helpers to manipulate them.
+"""
+
+# pylint: disable=wildcard-import
+from tensorflow.python.ops.losses.losses import *
+from tensorflow.python.ops.losses.util import *
diff --git a/tensorflow/python/ops/losses/losses.py b/tensorflow/python/ops/losses/losses.py
new file mode 100644
index 0000000000..e6c2a558b3
--- /dev/null
+++ b/tensorflow/python/ops/losses/losses.py
@@ -0,0 +1,588 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Loss operations for use in neural networks.
+
+Note: All the losses are added to the `GraphKeys.LOSSES` collection by default.
+
+@@absolute_difference
+@@compute_weighted_loss
+@@cosine_distance
+@@hinge_loss
+@@log_loss
+@@mean_pairwise_squared_error
+@@mean_squared_error
+@@sigmoid_cross_entropy
+@@softmax_cross_entropy
+@@sparse_softmax_cross_entropy
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.losses import util
+
+
+def _scale_losses(losses, weights):
+ """Computes the scaled loss.
+
+ Args:
+ losses: A `Tensor` of size [batch_size, d1, ... dN].
+ weights: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
+ The `losses` are reduced (tf.reduce_sum) until its dimension matches
+ that of `weights` at which point the reduced `losses` are element-wise
+ multiplied by `weights` and a final reduce_sum is computed on the result.
+ Conceptually, this operation is equivalent to broadcasting (tiling)
+ `weights` to be the same size as `losses`, performing an element-wise
+ multiplication, and summing the result.
+
+ Returns:
+ A scalar tf.float32 `Tensor` whose value represents the sum of the scaled
+ `losses`.
+ """
+ # First, compute the sum of the losses over all elements:
+ start_index = max(0, weights.get_shape().ndims)
+ reduction_indices = list(range(start_index, losses.get_shape().ndims))
+ reduced_losses = math_ops.reduce_sum(losses,
+ reduction_indices=reduction_indices)
+ reduced_losses = math_ops.mul(reduced_losses, weights)
+ return math_ops.reduce_sum(reduced_losses)
+
+
+def _safe_div(numerator, denominator, name="value"):
+ """Computes a safe divide which returns 0 if the denominator is zero.
+
+ Note that the function contains an additional conditional check that is
+ necessary for avoiding situations where the loss is zero causing NaNs to
+ creep into the gradient computation.
+
+ Args:
+ numerator: An arbitrary `Tensor`.
+ denominator: A `Tensor` whose shape matches `numerator` and whose values are
+ assumed to be non-negative.
+ name: An optional name for the returned op.
+
+ Returns:
+ The element-wise value of the numerator divided by the denominator.
+ """
+ return math_ops.select(
+ math_ops.greater(denominator, 0),
+ math_ops.div(numerator, math_ops.select(
+ math_ops.equal(denominator, 0),
+ array_ops.ones_like(denominator), denominator)),
+ array_ops.zeros_like(numerator),
+ name=name)
+
+
+def _safe_mean(losses, num_present):
+ """Computes a safe mean of the losses.
+
+ Args:
+ losses: A tensor whose elements contain individual loss measurements.
+ num_present: The number of measurable losses in the tensor.
+
+ Returns:
+ A scalar representing the mean of the losses. If `num_present` is zero,
+ then zero is returned.
+ """
+ total_loss = math_ops.reduce_sum(losses)
+ return _safe_div(total_loss, num_present)
+
+
+def _num_present(losses, weights, per_batch=False):
+ """Computes the number of elements in the loss function induced by `weights`.
+
+ A given weights tensor induces different numbers of usable elements in the
+ `losses` tensor. The `weights` tensor is broadcast across `losses` for all
+ possible dimensions. For example, if `losses` is a tensor of dimension
+ [4, 5, 6, 3] and `weights` is a tensor of size [4, 5], then `weights` is, in
+ effect, tiled to match the size of `losses`. Following this effective tile,
+ the total number of present elements is the number of non-zero weights.
+
+ Args:
+ losses: A tensor of size [batch_size, d1, ... dN].
+ weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ per_batch: Whether to return the number of elements per batch or as a sum
+ total.
+
+ Returns:
+ The number of present (non-zero) elements in the losses tensor. If
+ `per_batch` is True, the value is returned as a tensor of size
+ [batch_size]. Otherwise, a single scalar tensor is returned.
+ """
+ # If weights is a scalar, it's easy to compute:
+ if weights.get_shape().ndims == 0:
+ batch_size = array_ops.reshape(array_ops.slice(array_ops.shape(losses),
+ [0], [1]), [])
+ num_per_batch = math_ops.div(math_ops.to_float(array_ops.size(losses)),
+ math_ops.to_float(batch_size))
+ num_per_batch = math_ops.select(math_ops.equal(weights, 0),
+ 0.0, num_per_batch)
+ num_per_batch = math_ops.mul(array_ops.ones(
+ array_ops.reshape(batch_size, [1])), num_per_batch)
+ return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
+
+ # First, count the number of nonzero weights:
+ if weights.get_shape().ndims >= 1:
+ reduction_indices = list(range(1, weights.get_shape().ndims))
+ num_nonzero_per_batch = math_ops.reduce_sum(
+ math_ops.to_float(math_ops.not_equal(weights, 0)),
+ reduction_indices=reduction_indices)
+
+ # Next, determine the number of elements that weight would broadcast to:
+ broadcast_dims = array_ops.slice(array_ops.shape(losses),
+ [weights.get_shape().ndims], [-1])
+ num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))
+
+ num_per_batch = math_ops.mul(num_nonzero_per_batch, num_to_broadcast)
+ return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
+
+
+def compute_weighted_loss(
+ losses, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES):
+ """Computes the weighted loss.
+
+ Args:
+ losses: A tensor of size [batch_size, d1, ... dN].
+ weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
+ scope: the scope for the operations performed in computing the loss.
+ loss_collection: the loss will be added to these collections.
+
+ Returns:
+ A scalar `Tensor` that returns the weighted loss.
+
+ Raises:
+ ValueError: If `weights` is `None` or the shape is not compatible with
+ `losses`, or if the number of dimensions (rank) of either `losses` or
+ `weights` is missing.
+ """
+ with ops.name_scope(scope, "weighted_loss", [losses, weights]):
+ losses = ops.convert_to_tensor(losses)
+ input_dtype = losses.dtype
+ losses = math_ops.to_float(losses)
+ weights = math_ops.to_float(ops.convert_to_tensor(weights))
+
+ if losses.get_shape().ndims is None:
+ raise ValueError("losses.get_shape().ndims cannot be None")
+ weights_shape = weights.get_shape()
+ if weights_shape.ndims is None:
+ raise ValueError("weight.get_shape().ndims cannot be None")
+
+ if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
+ weights = array_ops.squeeze(weights, [-1])
+
+ total_loss = _scale_losses(losses, weights)
+ num_present = _num_present(losses, weights)
+ mean_loss = _safe_mean(total_loss, num_present)
+ # convert the result back to the input type
+ mean_loss = math_ops.cast(mean_loss, input_dtype)
+ util.add_loss(mean_loss, loss_collection)
+ return mean_loss
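+
+# Illustrative usage sketch for `compute_weighted_loss` above (hypothetical
+# values; assumes the op is exported as `tf.losses.compute_weighted_loss`).
+# A [batch_size] weights vector zeroes out whole samples, and only the
+# remaining ("present") elements contribute to the mean:
+#
+#   import tensorflow as tf
+#   losses = tf.constant([[0.5, 1.5], [2.0, 4.0]])  # [batch_size=2, d1=2]
+#   weights = tf.constant([1.0, 0.0])               # drop the second sample
+#   loss = tf.losses.compute_weighted_loss(losses, weights)
+#   # total weighted loss 2.0 over 2 present elements -> loss evaluates to 1.0;
+#   # the result is also added to the GraphKeys.LOSSES collection.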
+
+
+def absolute_difference(
+ labels, predictions, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds an Absolute Difference loss to the training procedure.
+
+ `weights` acts as a coefficient for the loss. If a scalar is provided, then
+ the loss is simply scaled by the given value. If `weights` is a tensor of
+ size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector. If the shape
+  of `weights` matches the shape of `predictions`, then the loss of each
+  measurable element of `predictions` is scaled by the corresponding value of
+  `weights`.
+
+ Args:
+ labels: The ground truth output tensor, same dimensions as 'predictions'.
+ predictions: The predicted outputs.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `predictions`.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which this loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+ """
+ with ops.name_scope(scope, "absolute_difference",
+ [predictions, labels, weights]) as scope:
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions = math_ops.to_float(predictions)
+ labels = math_ops.to_float(labels)
+ losses = math_ops.abs(math_ops.sub(predictions, labels))
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
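+
+# Illustrative sketch for `absolute_difference` above (hypothetical values;
+# assumes export as `tf.losses.absolute_difference`): with a scalar weight the
+# mean absolute error is simply scaled.
+#
+#   labels      = tf.constant([1.0, 2.0, 3.0])
+#   predictions = tf.constant([1.5, 2.0, 5.0])
+#   loss = tf.losses.absolute_difference(labels, predictions, weights=2.0)
+#   # mean(|predictions - labels|) == (0.5 + 0.0 + 2.0) / 3, scaled by 2.0,
+#   # so the loss evaluates to 5/3.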
+
+
+def cosine_distance(
+ labels, predictions, dim=None, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a cosine-distance loss to the training procedure.
+
+ Note that the function assumes that `predictions` and `labels` are already
+ unit-normalized.
+
+ Args:
+    labels: A `Tensor` whose shape matches `predictions`.
+ predictions: An arbitrary matrix.
+ dim: The dimension along which the cosine distance is computed.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `predictions`.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which this loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If `predictions` shape doesn't match `labels` shape, or
+ `weights` is `None`.
+ """
+ if dim is None:
+ raise ValueError("`dim` cannot be None.")
+ with ops.name_scope(scope, "cosine_distance_loss",
+ [predictions, labels, weights]) as scope:
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+
+ predictions = math_ops.to_float(predictions)
+ labels = math_ops.to_float(labels)
+
+ radial_diffs = math_ops.mul(predictions, labels)
+ losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,])
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
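+
+# Illustrative sketch for `cosine_distance` above (hypothetical values; assumes
+# export as `tf.losses.cosine_distance`). Inputs must already be
+# unit-normalized along `dim`, e.g. via tf.nn.l2_normalize:
+#
+#   a = tf.nn.l2_normalize(tf.constant([[1.0, 2.0, 3.0]]), dim=1)
+#   b = tf.nn.l2_normalize(tf.constant([[3.0, 2.0, 1.0]]), dim=1)
+#   loss = tf.losses.cosine_distance(a, b, dim=1)
+#   # 1 - <a, b> == 1 - 10/14, so the loss evaluates to roughly 0.286.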
+
+
+def hinge_loss(labels, logits, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a hinge loss to the training procedure.
+
+ Args:
+ labels: The ground truth output tensor. Its shape should match the shape of
+ logits. The values of the tensor are expected to be 0.0 or 1.0.
+ logits: The logits, a float tensor.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `logits`.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` of the loss value.
+
+ Raises:
+ ValueError: If the shapes of `logits` and `labels` don't match.
+ """
+ with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope:
+ logits.get_shape().assert_is_compatible_with(labels.get_shape())
+ # We first need to convert binary labels to -1/1 labels (as floats).
+ labels = math_ops.to_float(labels)
+ all_ones = array_ops.ones_like(labels)
+ labels = math_ops.sub(2 * labels, all_ones)
+ losses = nn_ops.relu(math_ops.sub(all_ones, math_ops.mul(labels, logits)))
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
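+
+# Illustrative sketch for `hinge_loss` above (hypothetical values; assumes
+# export as `tf.losses.hinge_loss`): binary {0, 1} labels are mapped to
+# {-1, +1}, so examples classified correctly with margin >= 1 contribute no
+# loss.
+#
+#   labels = tf.constant([0.0, 1.0, 1.0])
+#   logits = tf.constant([-2.0, 3.0, 0.5])
+#   loss = tf.losses.hinge_loss(labels, logits)
+#   # per-element max(0, 1 - y * logit) with y in {-1, +1}: [0.0, 0.0, 0.5],
+#   # so the mean loss evaluates to 0.5 / 3.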
+
+
+def log_loss(labels, predictions, weights=1.0, epsilon=1e-7, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a Log Loss term to the training procedure.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector. If the shape
+  of `weights` matches the shape of `predictions`, then the loss of each
+  measurable element of `predictions` is scaled by the corresponding value of
+  `weights`.
+
+ Args:
+ labels: The ground truth output tensor, same dimensions as 'predictions'.
+ predictions: The predicted outputs.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `predictions`.
+ epsilon: A small increment to add to avoid taking a log of zero.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+ """
+ with ops.name_scope(scope, "log_loss",
+ [predictions, labels, weights]) as scope:
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions = math_ops.to_float(predictions)
+ labels = math_ops.to_float(labels)
+ losses = -math_ops.mul(
+ labels,
+ math_ops.log(predictions + epsilon)) - math_ops.mul(
+ (1 - labels), math_ops.log(1 - predictions + epsilon))
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
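+
+# Illustrative sketch for `log_loss` above (hypothetical values; assumes export
+# as `tf.losses.log_loss`): this is cross-entropy on probabilities rather than
+# logits, with `epsilon` guarding the log.
+#
+#   labels      = tf.constant([1.0, 0.0])
+#   predictions = tf.constant([0.9, 0.2])
+#   loss = tf.losses.log_loss(labels, predictions)
+#   # per element: -(y*log(p) + (1-y)*log(1-p)) == [-log(0.9), -log(0.8)],
+#   # so the mean loss evaluates to roughly 0.164.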
+
+
+def mean_pairwise_squared_error(labels, predictions, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a pairwise-errors-squared loss to the training procedure.
+
+ Unlike `mean_squared_error`, which is a measure of the differences between
+ corresponding elements of `predictions` and `labels`,
+ `mean_pairwise_squared_error` is a measure of the differences between pairs of
+ corresponding elements of `predictions` and `labels`.
+
+  For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], there are
+  three pairs of differences, which are summed to compute the loss:
+ loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3
+
+ Note that since the inputs are of size [batch_size, d0, ... dN], the
+ corresponding pairs are computed within each batch sample but not across
+ samples within a batch. For example, if `predictions` represents a batch of
+ 16 grayscale images of dimension [batch_size, 100, 200], then the set of pairs
+ is drawn from each image, but not across images.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector.
+
+ Args:
+ labels: The ground truth output tensor, whose shape must match the shape of
+ the `predictions` tensor.
+ predictions: The predicted outputs, a tensor of size [batch_size, d0, .. dN]
+ where N+1 is the total number of dimensions in `predictions`.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `predictions`.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+ """
+ with ops.name_scope(scope, "mean_pairwise_squared_error",
+ [predictions, labels, weights]) as scope:
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions = math_ops.to_float(predictions)
+ labels = math_ops.to_float(labels)
+ weights = math_ops.to_float(ops.convert_to_tensor(weights))
+
+ diffs = math_ops.sub(predictions, labels)
+
+ # Need to verify here since the function doesn't use compute_weighted_loss
+ if diffs.get_shape().ndims is None:
+ raise ValueError("diffs.get_shape().ndims cannot be None")
+ if weights.get_shape().ndims is None:
+ raise ValueError("weights.get_shape().ndims cannot be None")
+
+ reduction_indices = list(range(1, diffs.get_shape().ndims))
+
+ sum_squares_diff_per_batch = math_ops.reduce_sum(
+ math_ops.square(diffs),
+ reduction_indices=reduction_indices)
+ num_present_per_batch = _num_present(diffs, weights, per_batch=True)
+
+ term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
+ num_present_per_batch)
+
+ sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
+ term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
+ math_ops.square(num_present_per_batch))
+
+ loss = _scale_losses(term1 - term2, weights)
+
+ mean_loss = math_ops.select(math_ops.reduce_sum(num_present_per_batch) > 0,
+ loss,
+ array_ops.zeros_like(loss),
+ name="value")
+ util.add_loss(mean_loss, loss_collection)
+ return mean_loss
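+
+# Illustrative sketch for `mean_pairwise_squared_error` above (hypothetical
+# values; assumes export as `tf.losses.mean_pairwise_squared_error`): only
+# relative errors matter, so a constant offset between predictions and labels
+# contributes nothing.
+#
+#   labels      = tf.constant([[1.0, 2.0, 4.0]])
+#   predictions = tf.constant([[3.0, 4.0, 6.0]])  # labels shifted by +2
+#   loss = tf.losses.mean_pairwise_squared_error(labels, predictions)
+#   # every pairwise difference matches, so the loss evaluates to 0.0;
+#   # contrast with mean_squared_error below, which would report 4.0 here.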
+
+
+def mean_squared_error(labels, predictions, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a Sum-of-Squares loss to the training procedure.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided, then
+  the loss is simply scaled by the given value. If `weights` is a tensor of
+  size [batch_size], then the total loss for each sample of the batch is
+  rescaled by the corresponding element in the `weights` vector. If the shape
+  of `weights` matches the shape of `predictions`, then the loss of each
+  measurable element of `predictions` is scaled by the corresponding value of
+  `weights`.
+
+ Args:
+ labels: The ground truth output tensor, same dimensions as 'predictions'.
+ predictions: The predicted outputs.
+    weights: Coefficients for the loss. This must be a scalar, a tensor of
+      shape [batch_size], or a tensor whose shape matches `predictions`.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `predictions` doesn't match that of `labels` or
+      if the shape of `weights` is invalid.
+ """
+ with ops.name_scope(scope, "mean_squared_error",
+ [predictions, labels, weights]) as scope:
+ predictions.get_shape().assert_is_compatible_with(labels.get_shape())
+ predictions = math_ops.to_float(predictions)
+ labels = math_ops.to_float(labels)
+ losses = math_ops.square(math_ops.sub(predictions, labels))
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
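+
+# Illustrative sketch for `mean_squared_error` above (hypothetical values;
+# assumes export as `tf.losses.mean_squared_error`):
+#
+#   labels      = tf.constant([1.0, 2.0, 3.0])
+#   predictions = tf.constant([1.0, 2.0, 6.0])
+#   loss = tf.losses.mean_squared_error(labels, predictions)
+#   # mean of squared errors: (0 + 0 + 9) / 3, so the loss evaluates to 3.0.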
+
+
+def sigmoid_cross_entropy(
+ multi_class_labels, logits, weights=1.0, label_smoothing=0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weights` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+ If `label_smoothing` is nonzero, smooth the labels towards 1/2:
+
+      new_multi_class_labels = multi_class_labels * (1 - label_smoothing)
+                               + 0.5 * label_smoothing
+
+ Args:
+    multi_class_labels: [batch_size, num_classes] target labels in {0, 1}.
+    logits: [batch_size, num_classes] logits outputs of the network.
+ weights: Coefficients for the loss. The tensor must be a scalar, a tensor of
+ shape [batch_size] or shape [batch_size, num_classes].
+ label_smoothing: If greater than 0 then smooth the labels.
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `logits` doesn't match that of
+      `multi_class_labels` or if the shape of `weights` is invalid, or if
+      `weights` is None.
+ """
+ with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
+ [logits, multi_class_labels, weights]) as scope:
+ logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())
+
+ multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)
+
+ if label_smoothing > 0:
+ multi_class_labels = (multi_class_labels * (1 - label_smoothing) +
+ 0.5 * label_smoothing)
+
+ losses = nn.sigmoid_cross_entropy_with_logits(logits, multi_class_labels,
+ name="xentropy")
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
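+
+# Illustrative sketch of the label-smoothing step in `sigmoid_cross_entropy`
+# above (hypothetical values; assumes export as
+# `tf.losses.sigmoid_cross_entropy`):
+#
+#   multi_class_labels = tf.constant([[0.0, 1.0]])
+#   logits             = tf.constant([[-1.0, 2.0]])
+#   loss = tf.losses.sigmoid_cross_entropy(multi_class_labels, logits,
+#                                          label_smoothing=0.1)
+#   # the labels are first smoothed towards 0.5:
+#   # [[0.0, 1.0]] * 0.9 + 0.05 == [[0.05, 0.95]]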
+
+
+def softmax_cross_entropy(
+ onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weights` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+ If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:
+ new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ + label_smoothing / num_classes
+
+ Args:
+ onehot_labels: [batch_size, num_classes] target one_hot_encoded labels.
+    logits: [batch_size, num_classes] logits outputs of the network.
+ weights: Coefficients for the loss. The tensor must be a scalar or a tensor
+ of shape [batch_size].
+ label_smoothing: If greater than 0 then smooth the labels.
+ scope: the scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+ ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
+      or if the shape of `weights` is invalid or if `weights` is None.
+ """
+ with ops.name_scope(scope, "softmax_cross_entropy_loss",
+ [logits, onehot_labels, weights]) as scope:
+ logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())
+
+ onehot_labels = math_ops.cast(onehot_labels, logits.dtype)
+
+ if label_smoothing > 0:
+ num_classes = math_ops.cast(
+ array_ops.shape(onehot_labels)[1], logits.dtype)
+ smooth_positives = 1.0 - label_smoothing
+ smooth_negatives = label_smoothing / num_classes
+ onehot_labels = onehot_labels * smooth_positives + smooth_negatives
+
+ losses = nn.softmax_cross_entropy_with_logits(logits, onehot_labels,
+ name="xentropy")
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
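+
+# Illustrative sketch of the label-smoothing step in `softmax_cross_entropy`
+# above (hypothetical values; assumes export as
+# `tf.losses.softmax_cross_entropy`):
+#
+#   onehot_labels = tf.constant([[0.0, 1.0, 0.0]])
+#   logits        = tf.constant([[1.0, 3.0, 0.5]])
+#   loss = tf.losses.softmax_cross_entropy(onehot_labels, logits,
+#                                          label_smoothing=0.3)
+#   # with num_classes == 3 the labels become
+#   # [[0.0, 1.0, 0.0]] * 0.7 + 0.3 / 3 == [[0.1, 0.8, 0.1]]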
+
+
+def sparse_softmax_cross_entropy(labels, logits, weights=1.0, scope=None,
+ loss_collection=ops.GraphKeys.LOSSES):
+ """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.
+
+  `weights` acts as a coefficient for the loss. If a scalar is provided,
+  then the loss is simply scaled by the given value. If `weights` is a
+  tensor of size [`batch_size`], then the loss weights apply to each
+  corresponding sample.
+
+ Args:
+ labels: [batch_size, 1] or [batch_size] target labels of dtype `int32` or
+ `int64` in the range `[0, num_classes)`.
+    logits: [batch_size, num_classes] logits outputs of the network.
+ weights: Coefficients for the loss. The tensor must be a scalar or a tensor
+ of shape [batch_size] or [batch_size, 1].
+ scope: the scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+
+ Returns:
+ A scalar `Tensor` representing the loss value.
+
+ Raises:
+    ValueError: If the shapes of `logits`, `labels`, and `weights` are
+      incompatible, or if `weights` is None.
+ """
+ with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
+ [logits, labels, weights]) as scope:
+ labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
+ weights = array_ops.squeeze(weights)
+
+ losses = nn.sparse_softmax_cross_entropy_with_logits(logits, labels,
+ name="xentropy")
+ return compute_weighted_loss(losses, weights, scope, loss_collection)
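+
+# Illustrative sketch for `sparse_softmax_cross_entropy` above (hypothetical
+# values; assumes export as `tf.losses.sparse_softmax_cross_entropy`): labels
+# are integer class indices rather than one-hot vectors.
+#
+#   labels = tf.constant([2, 0])
+#   logits = tf.constant([[0.1, 0.2, 3.0], [2.0, 0.1, 0.1]])
+#   loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)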
diff --git a/tensorflow/python/ops/losses/util.py b/tensorflow/python/ops/losses/util.py
new file mode 100644
index 0000000000..aaf324891f
--- /dev/null
+++ b/tensorflow/python/ops/losses/util.py
@@ -0,0 +1,88 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities for manipulating the loss collections.
+
+
+@@add_loss
+@@get_losses
+@@get_regularization_losses
+@@get_total_loss
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+
+
+def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
+ """Adds a externally defined loss to the collection of losses.
+
+ Args:
+ loss: A loss `Tensor`.
+ loss_collection: Optional collection to add the loss to.
+ """
+ if loss_collection:
+ ops.add_to_collection(loss_collection, loss)
+
+
+def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
+ """Gets the list of losses from the loss_collection.
+
+ Args:
+ scope: an optional scope for filtering the losses to return.
+ loss_collection: Optional losses collection.
+
+ Returns:
+ a list of loss tensors.
+ """
+ return ops.get_collection(loss_collection, scope)
+
+
+def get_regularization_losses(scope=None):
+ """Gets the regularization losses.
+
+ Args:
+ scope: an optional scope for filtering the losses to return.
+
+ Returns:
+ A list of loss variables.
+ """
+ return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
+
+
+def get_total_loss(add_regularization_losses=True, name="total_loss"):
+ """Returns a tensor whose value represents the total loss.
+
+  Notice that the regularization losses are added to the collected losses
+  before summing when `add_regularization_losses` is `True`.
+
+ Args:
+ add_regularization_losses: A boolean indicating whether or not to use the
+ regularization losses in the sum.
+ name: The name of the returned tensor.
+
+ Returns:
+ A `Tensor` whose value represents the total loss.
+
+ Raises:
+ ValueError: if `losses` is not iterable.
+ """
+ losses = get_losses()
+ if add_regularization_losses:
+ losses += get_regularization_losses()
+ return math_ops.add_n(losses, name=name)
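+
+
+# Illustrative sketch of the collection workflow (hypothetical tensors
+# `labels_a`, `predictions_a`, etc.; assumes the ops are exported under
+# `tf.losses`): individual losses register themselves in GraphKeys.LOSSES and
+# are then summed in one place.
+#
+#   loss_a = tf.losses.mean_squared_error(labels_a, predictions_a)
+#   loss_b = tf.losses.log_loss(labels_b, predictions_b)
+#   total = tf.losses.get_total_loss()  # loss_a + loss_b (+ regularization)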