diff options
author | 2017-07-17 15:11:33 -0700 | |
---|---|---|
committer | 2017-07-17 15:15:41 -0700 | |
commit | 5799c01281ff607214f755693b98d03dd0847e18 (patch) | |
tree | 4a65f30353bb2a44349d86001e0ff9d53b843a0e /tensorflow/contrib/kernel_methods | |
parent | 625ddbaf2d64d5b53551480eb51d2d1c02775d2d (diff) |
Adds multi-class (Crammer-Singer) hinge loss to tf.contrib.kernel_methods
PiperOrigin-RevId: 162277054
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r-- | tensorflow/contrib/kernel_methods/BUILD | 17 | ||||
-rw-r--r-- | tensorflow/contrib/kernel_methods/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/kernel_methods/python/losses.py | 135 | ||||
-rw-r--r-- | tensorflow/contrib/kernel_methods/python/losses_test.py | 206 |
4 files changed, 360 insertions, 0 deletions
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index fccaa3abd4..ae1402b0e6 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -14,6 +14,7 @@ py_library( srcs = [ "__init__.py", "python/kernel_estimators.py", + "python/losses.py", "python/mappers/random_fourier_features.py", ], srcs_version = "PY2AND3", @@ -22,11 +23,15 @@ py_library( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:util", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@six_archive//:six", ], @@ -71,6 +76,18 @@ py_test( ], ) +py_test( + name = "losses_test", + srcs = ["python/losses_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":kernel_methods", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/kernel_methods/__init__.py b/tensorflow/contrib/kernel_methods/__init__.py index 7272e59516..0f3827d187 100644 --- a/tensorflow/contrib/kernel_methods/__init__.py +++ b/tensorflow/contrib/kernel_methods/__init__.py @@ -16,12 +16,14 @@ @@KernelLinearClassifier @@RandomFourierFeatureMapper +@@sparse_multiclass_hinge_loss """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier +from tensorflow.contrib.kernel_methods.python.losses import sparse_multiclass_hinge_loss from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper 
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of kernel-methods-related loss operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.losses import losses


def sparse_multiclass_hinge_loss(
    labels,
    logits,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Adds Ops for computing the multiclass hinge loss.

  The implementation is based on the following paper:
  On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
  by Crammer and Singer.
  link: http://jmlr.csail.mit.edu/papers/volume2/crammer01a/crammer01a.pdf

  This is a generalization of standard (binary) hinge loss. For a given instance
  with correct label c*, the loss is given by:
    loss = max(0, max_{c != c*} (logits_c - logits_{c*} + 1))
  or equivalently
    loss = max_c { logits_c - logits_{c*} + I_{c != c*} }
  where I_{c != c*} = 1 if c != c* and 0 otherwise. (The term for c = c* is
  always 0, so the second form is never negative.)

  Args:
    labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to
      the ground truth. Each entry must be an index in `[0, num_classes)`; both
      bounds are enforced with runtime assertions. Its dtype must be `int32` or
      `int64`.
    logits: `Tensor` of shape [batch_size, num_classes] corresponding to the
      unscaled logits. Its dtype should be either `float32` or `float64`; it is
      cast to `float32` before the loss is computed. `num_classes` must be
      statically known; `batch_size` may be unknown at graph-construction time.
    weights: Optional (python) scalar or `Tensor`. If a non-scalar `Tensor`, its
      rank should be either 1 ([batch_size]) or 2 ([batch_size, 1]).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has shape
    [batch_size]; otherwise, it is a scalar.

  Raises:
    ValueError: If `logits`, `labels` or `weights` have invalid or inconsistent
      shapes.
    ValueError: If `labels` tensor has invalid dtype.
    ValueError: If `weights` is None (it cannot be converted to a `Tensor`).
  """

  with ops.name_scope(scope, 'sparse_multiclass_hinge_loss', (logits,
                                                              labels)) as scope:

    # Check logits Tensor has valid rank.
    logits_shape = logits.get_shape()
    logits_rank = logits_shape.ndims
    if logits_rank != 2:
      raise ValueError(
          'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
          ' {}'.format(logits_rank))
    batch_size, num_classes = logits_shape[0].value, logits_shape[1].value
    if batch_size is None:
      # The batch dimension is not statically known (e.g. a placeholder with a
      # leading `None`); fall back to the dynamic shape so that the range and
      # reshape ops below can still be built.
      batch_size = array_ops.shape(logits)[0]
    logits = math_ops.to_float(logits)

    # Check labels have valid type.
    if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64:
      raise ValueError(
          'Invalid dtype for labels: {}. Acceptable dtypes: int32 and int64'.
          format(labels.dtype))

    # Check labels and weights have valid ranks and are consistent.
    labels_rank = labels.get_shape().ndims
    if labels_rank not in [1, 2]:
      raise ValueError(
          'labels should have rank 1 ([batch_size]) or 2 ([batch_size, 1]). '
          'Given rank is {}'.format(labels_rank))

    # Enforce the documented label range [0, num_classes) at run time. The
    # lower bound matters as well: a negative label would otherwise be used as
    # a gather index below.
    label_range_checks = [
        check_ops.assert_non_negative(
            labels, message='labels must be non-negative.'),
        check_ops.assert_less(
            labels,
            math_ops.cast(num_classes, labels.dtype),
            message='labels must be smaller than num_classes.'),
    ]
    with ops.control_dependencies(label_range_checks):
      labels = array_ops.reshape(labels, shape=[-1])

    weights = ops.convert_to_tensor(weights)
    weights_rank = weights.get_shape().ndims
    if weights_rank not in [0, 1, 2]:
      # Report the rank of `weights` (not of `labels`) in the error message.
      raise ValueError(
          'non-scalar weights should have rank 1 ([batch_size]) or 2 '
          '([batch_size, 1]). Given rank is {}'.format(weights_rank))

    if weights_rank > 0:
      weights = array_ops.reshape(weights, shape=[-1])
      # Check weights and labels have the same number of elements.
      weights.get_shape().assert_is_compatible_with(labels.get_shape())

    # Compute the logits tensor corresponding to the correct class per instance.
    # `indices` holds one (row, label) pair per example, so gather_nd picks
    # logits[i, labels[i]] for every row i.
    example_indices = array_ops.reshape(
        math_ops.range(batch_size), shape=[batch_size, 1])
    indices = array_ops.concat(
        [
            example_indices,
            array_ops.reshape(
                math_ops.cast(labels, example_indices.dtype),
                shape=[batch_size, 1])
        ],
        axis=1)
    label_logits = array_ops.reshape(
        array_ops.gather_nd(params=logits, indices=indices),
        shape=[batch_size, 1])

    # "One-cold" encoding: 0 at the true class and 1 everywhere else, i.e. the
    # indicator I_{c != c*} from the docstring.
    one_cold_labels = array_ops.one_hot(
        indices=labels, depth=num_classes, on_value=0.0, off_value=1.0)
    margin = logits - label_logits + one_cold_labels
    margin = nn_ops.relu(margin)
    loss = math_ops.reduce_max(margin, axis=1)
    return losses.compute_weighted_loss(
        loss, weights, scope, loss_collection, reduction=reduction)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.contrib.kernel_methods.python.losses."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.kernel_methods.python import losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test


class SparseMulticlassHingeLossTest(test.TestCase):
  """Tests for `losses.sparse_multiclass_hinge_loss`.

  The first group of tests checks that invalid shapes, dtypes and label values
  are rejected; the second group checks the numerical value of the loss for
  unweighted and weighted inputs. Unless weights are given, the expected values
  are the per-example losses averaged over the batch.
  """

  def testInvalidLogitsShape(self):
    """An error is raised when logits have invalid shape."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2,))
      labels = constant_op.constant([0, 1])
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits)

  def testInvalidLabelsShape(self):
    """An error is raised when labels have invalid shape."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
      labels = constant_op.constant([1, 0], shape=(1, 1, 2))
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits)

  def testInvalidWeightsShape(self):
    """An error is raised when weights have invalid shape."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
      labels = constant_op.constant([1, 0], shape=(2,))
      weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)

  def testInvalidLabelsDtype(self):
    """An error is raised when labels have invalid dtype."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
      labels = constant_op.constant([1, 0], dtype=dtypes.float32)
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits)

  def testNoneWeightRaisesValueError(self):
    """An error is raised when weights are None."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
      labels = constant_op.constant([1, 0])
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights=None)

  def testInconsistentLabelsAndWeightsShapesSameRank(self):
    """Error raised when weights and labels have same ranks, different sizes."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)

  def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
    """Error raised when weights and labels have different ranks and sizes."""
    with self.test_session():
      logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
      labels = constant_op.constant([1, 0], shape=(2, 1))
      weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
      with self.assertRaises(ValueError):
        _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)

  def testOutOfRangeLabels(self):
    """An error is raised when labels are not in [0, num_classes)."""
    with self.test_session():
      logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                                     [0.5, 1.8, -1.0]])
      labels = constant_op.constant([1, 0, 4])
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      # The range check is a runtime assertion, so the error only surfaces
      # when the loss is evaluated, not when the graph is built.
      with self.assertRaises(errors.InvalidArgumentError):
        loss.eval()

  def testZeroLossInt32Labels(self):
    """Loss is 0 if true class logits sufficiently higher than other classes."""
    with self.test_session():
      logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
                                     [0.5, 1.8, -1.0]])
      labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testZeroLossInt64Labels(self):
    """Loss is 0 if true class logits sufficiently higher than other classes."""
    with self.test_session():
      logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
                                     [-0.5, 0.8, -1.0]])
      labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testCorrectPredictionsSomeClassesInsideMargin(self):
    """Loss is > 0 even if true class logits are higher than other classes."""
    with self.test_session():
      logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
                                     [1.5, 1.8, -1.0]])
      labels = constant_op.constant([0, 2, 1])
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      # The first and third samples incur some loss (0.6 and 0.7 respectively),
      # so the mean over the 3 examples is (0.6 + 0.0 + 0.7) / 3 = 0.4333.
      self.assertAlmostEqual(loss.eval(), 0.4333, 3)

  def testIncorrectPredictions(self):
    """Loss is >0 when an incorrect class has higher logits than true class."""
    with self.test_session():
      logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
                                     [0.5, -1.8, 2.0]])
      labels = constant_op.constant([1, 0, 2])
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      # The first example incurs a high loss (3.2) since the logits of an
      # incorrect class (0) are higher than the logits of the ground truth. The
      # second example also incurs a (smaller) loss (0.4).
      self.assertAlmostEqual(loss.eval(), 1.2, 3)

  def testIncorrectPredictionsColumnLabels(self):
    """Same as above but labels is a rank-2 tensor."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0]])
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      loss = losses.sparse_multiclass_hinge_loss(labels, logits)
      # The first example incurs a high loss (3.0) since the logits of an
      # incorrect class (0) are higher than the logits of the ground truth. The
      # second example also incurs a (smaller) loss (0.3).
      self.assertAlmostEqual(loss.eval(), 1.1, 3)

  def testIncorrectPredictionsZeroWeights(self):
    """Loss is 0 when all weights are zero even if predictions are wrong."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0]])
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      weights = constant_op.constant([0.0, 0.0, 0.0], shape=(3, 1))
      loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
      # No overall loss since all weights are 0.
      self.assertAlmostEqual(loss.eval(), 0.0, 3)

  def testNonZeroLossWithPythonScalarWeights(self):
    """Weighted loss is correctly computed when weights is a python scalar."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0]])
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      weights = 10.0
      loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
      # Unweighted mean loss is 1.1; the scalar weight scales it by 10.
      self.assertAlmostEqual(loss.eval(), 11.0, 3)

  def testNonZeroLossWithScalarTensorWeights(self):
    """Weighted loss is correctly computed when weights is a rank-0 tensor."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0]])
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      weights = constant_op.constant(5.0)
      loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
      # Unweighted mean loss is 1.1; the scalar weight scales it by 5.
      self.assertAlmostEqual(loss.eval(), 5.5, 3)

  def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
    """Weighted loss is correctly computed when weights is a rank-1 tensor."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0]])
      labels = constant_op.constant([1, 0, 2], shape=(3, 1))
      weights = constant_op.constant([1.0, 0.5, 2.0], shape=(3,))
      loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
      # The overall loss is 1/3 * (3.0*1.0 + 0.5*0.3 + 2.0*0.0) = 1.05
      self.assertAlmostEqual(loss.eval(), 1.05, 3)

  def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
    """Weighted loss is correctly computed when weights is a rank-2 tensor."""
    with self.test_session():
      logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
                                     [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
      labels = constant_op.constant([1, 0, 2, 1])
      weights = constant_op.constant([[1.0], [0.0], [2.0], [4.0]])
      loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
      # Only 3 of the 4 weights are non-zero, so the sum is divided by 3:
      # 1/3 * (3.0*1.0 + 0.0*0.3 + 2.0*0.0 + 4.0*0.8) = 6.2/3.
      self.assertAlmostEqual(loss.eval(), 2.06666, 3)


if __name__ == '__main__':
  test.main()