author     Petros Mol <pmol@google.com>                       2017-07-17 15:11:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2017-07-17 15:15:41 -0700
commit     5799c01281ff607214f755693b98d03dd0847e18 (patch)
tree       4a65f30353bb2a44349d86001e0ff9d53b843a0e /tensorflow/contrib/kernel_methods
parent     625ddbaf2d64d5b53551480eb51d2d1c02775d2d (diff)
Adds multi-class (Crammer-Singer) hinge loss to tf.contrib.kernel_methods
PiperOrigin-RevId: 162277054
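
A minimal usage sketch of the newly exported loss (illustrative values, assuming a
TF 1.x graph session; not part of this change):

    import tensorflow as tf

    # Three examples, three classes; labels are sparse class indices.
    logits = tf.constant([[1.6, -0.4, 0.8],
                          [1.5, 0.8, -1.0],
                          [0.2, -1.8, 4.0]])
    labels = tf.constant([1, 0, 2])

    # Mean Crammer-Singer hinge loss over the batch (default reduction).
    loss = tf.contrib.kernel_methods.sparse_multiclass_hinge_loss(
        labels=labels, logits=logits)

    with tf.Session() as sess:
        print(sess.run(loss))  # 1.1 for these values
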
Diffstat (limited to 'tensorflow/contrib/kernel_methods')
-rw-r--r--  tensorflow/contrib/kernel_methods/BUILD                  |  17
-rw-r--r--  tensorflow/contrib/kernel_methods/__init__.py            |   2
-rw-r--r--  tensorflow/contrib/kernel_methods/python/losses.py       | 135
-rw-r--r--  tensorflow/contrib/kernel_methods/python/losses_test.py  | 206
4 files changed, 360 insertions, 0 deletions
diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD
index fccaa3abd4..ae1402b0e6 100644
--- a/tensorflow/contrib/kernel_methods/BUILD
+++ b/tensorflow/contrib/kernel_methods/BUILD
@@ -14,6 +14,7 @@ py_library(
srcs = [
"__init__.py",
"python/kernel_estimators.py",
+ "python/losses.py",
"python/mappers/random_fourier_features.py",
],
srcs_version = "PY2AND3",
@@ -22,11 +23,15 @@ py_library(
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:util",
+ "//tensorflow/python/ops/losses",
"//third_party/py/numpy",
"@six_archive//:six",
],
@@ -71,6 +76,18 @@ py_test(
],
)
+py_test(
+ name = "losses_test",
+ srcs = ["python/losses_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":kernel_methods",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/kernel_methods/__init__.py b/tensorflow/contrib/kernel_methods/__init__.py
index 7272e59516..0f3827d187 100644
--- a/tensorflow/contrib/kernel_methods/__init__.py
+++ b/tensorflow/contrib/kernel_methods/__init__.py
@@ -16,12 +16,14 @@
@@KernelLinearClassifier
@@RandomFourierFeatureMapper
+@@sparse_multiclass_hinge_loss
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.kernel_methods.python.kernel_estimators import KernelLinearClassifier
+from tensorflow.contrib.kernel_methods.python.losses import sparse_multiclass_hinge_loss
from tensorflow.contrib.kernel_methods.python.mappers.random_fourier_features import RandomFourierFeatureMapper
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py
new file mode 100644
index 0000000000..208b0e1c9d
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/losses.py
@@ -0,0 +1,135 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of kernel-methods-related loss operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops.losses import losses
+
+
+def sparse_multiclass_hinge_loss(
+ labels,
+ logits,
+ weights=1.0,
+ scope=None,
+ loss_collection=ops.GraphKeys.LOSSES,
+ reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
+ """Adds Ops for computing the multiclass hinge loss.
+
+ The implementation is based on the following paper:
+ On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
+ by Crammer and Singer.
+ link: http://jmlr.csail.mit.edu/papers/volume2/crammer01a/crammer01a.pdf
+
+ This is a generalization of standard (binary) hinge loss. For a given instance
+ with correct label c*, the loss is given by:
+ loss = max(0, max_{c != c*} logits_c - logits_{c*} + 1),
+ or equivalently
+ loss = max_c { logits_c - logits_{c*} + I_{c != c*} }
+ where I_{c != c*} = 1 if c != c* and 0 otherwise.
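+
+ For example, with 3 classes, logits [2.0, 1.0, 0.5] and correct label c* = 0,
+ the loss is max{0, 1.0 - 2.0 + 1, 0.5 - 2.0 + 1} = 0; with correct label
+ c* = 1 it is max{2.0 - 1.0 + 1, 0, 0.5 - 1.0 + 1} = 2.0.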
+
+ Args:
+ labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to
+ the ground truth. Each entry must be an index in `[0, num_classes)`.
+ logits: `Tensor` of shape [batch_size, num_classes] corresponding to the
+ unscaled logits. Its dtype should be either `float32` or `float64`.
+ weights: Optional (python) scalar or `Tensor`. If a non-scalar `Tensor`, its
+ rank should be either 1 ([batch_size]) or 2 ([batch_size, 1]).
+ scope: The scope for the operations performed in computing the loss.
+ loss_collection: collection to which the loss will be added.
+ reduction: Type of reduction to apply to loss.
+
+ Returns:
+ Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
+ shape as `labels`; otherwise, it is a scalar.
+
+ Raises:
+ ValueError: If `logits`, `labels` or `weights` have invalid or inconsistent
+ shapes.
+ ValueError: If `labels` tensor has invalid dtype.
+ """
+
+ with ops.name_scope(scope, 'sparse_multiclass_hinge_loss', (logits,
+ labels)) as scope:
+
+ # Check logits Tensor has valid rank.
+ logits_shape = logits.get_shape()
+ logits_rank = logits_shape.ndims
+ if logits_rank != 2:
+ raise ValueError(
+ 'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
+ ' {}'.format(logits_rank))
+ batch_size, num_classes = logits_shape[0].value, logits_shape[1].value
+ logits = math_ops.to_float(logits)
+
+ # Check labels have valid type.
+ if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64:
+ raise ValueError(
+ 'Invalid dtype for labels: {}. Acceptable dtypes: int32 and int64'.
+ format(labels.dtype))
+
+ # Check labels and weights have valid ranks and are consistent.
+ labels_rank = labels.get_shape().ndims
+ if labels_rank not in [1, 2]:
+ raise ValueError(
+ 'labels should have rank 1 ([batch_size]) or 2 ([batch_size, 1]). '
+ 'Given rank is {}'.format(labels_rank))
+ with ops.control_dependencies([
+ check_ops.assert_less(labels, math_ops.cast(num_classes, labels.dtype))
+ ]):
+ labels = array_ops.reshape(labels, shape=[-1])
+
+ weights = ops.convert_to_tensor(weights)
+ weights_rank = weights.get_shape().ndims
+ if weights_rank not in [0, 1, 2]:
+ raise ValueError(
+ 'non-scalar weights should have rank 1 ([batch_size]) or 2 '
+ '([batch_size, 1]). Given rank is {}'.format(weights_rank))
+
+ if weights_rank > 0:
+ weights = array_ops.reshape(weights, shape=[-1])
+ # Check weights and labels have the same number of elements.
+ weights.get_shape().assert_is_compatible_with(labels.get_shape())
+
+ # Compute the logits tensor corresponding to the correct class per instance.
+ example_indices = array_ops.reshape(
+ math_ops.range(batch_size), shape=[batch_size, 1])
+ indices = array_ops.concat(
+ [
+ example_indices,
+ array_ops.reshape(
+ math_ops.cast(labels, example_indices.dtype),
+ shape=[batch_size, 1])
+ ],
+ axis=1)
+ label_logits = array_ops.reshape(
+ array_ops.gather_nd(params=logits, indices=indices),
+ shape=[batch_size, 1])
+
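+ # For each instance, relu gives margin_c = max(0, logits_c - logits_{c*} +
+ # I_{c != c*}); the per-instance loss is the max of margin_c over classes,
+ # i.e. the Crammer-Singer hinge loss described in the docstring.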
+ one_cold_labels = array_ops.one_hot(
+ indices=labels, depth=num_classes, on_value=0.0, off_value=1.0)
+ margin = logits - label_logits + one_cold_labels
+ margin = nn_ops.relu(margin)
+ loss = math_ops.reduce_max(margin, axis=1)
+ return losses.compute_weighted_loss(
+ loss, weights, scope, loss_collection, reduction=reduction)
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py
new file mode 100644
index 0000000000..8a1a5ffe56
--- /dev/null
+++ b/tensorflow/contrib/kernel_methods/python/losses_test.py
@@ -0,0 +1,206 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for third_party.tensorflow.contrib.kernel_methods.python.losses."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.kernel_methods.python import losses
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class SparseMulticlassHingeLossTest(test.TestCase):
+
+ def testInvalidLogitsShape(self):
+ """An error is raised when logits have invalid shape."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2,))
+ labels = constant_op.constant([0, 1])
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits)
+
+ def testInvalidLabelsShape(self):
+ """An error is raised when labels have invalid shape."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
+ labels = constant_op.constant([1, 0], shape=(1, 1, 2))
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits)
+
+ def testInvalidWeightsShape(self):
+ """An error is raised when weights have invalid shape."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
+ labels = constant_op.constant([1, 0], shape=(2,))
+ weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+
+ def testInvalidLabelsDtype(self):
+ """An error is raised when labels have invalid shape."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
+ labels = constant_op.constant([1, 0], dtype=dtypes.float32)
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits)
+
+ def testNoneWeightRaisesValueError(self):
+ """An error is raised when weights are None."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
+ labels = constant_op.constant([1, 0])
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights=None)
+
+ def testInconsistentLabelsAndWeightsShapesSameRank(self):
+ """Error raised when weights and labels have same ranks, different sizes."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+
+ def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
+ """Error raised when weights and labels have different ranks and sizes."""
+ with self.test_session():
+ logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
+ labels = constant_op.constant([1, 0], shape=(2, 1))
+ weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
+ with self.assertRaises(ValueError):
+ _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+
+ def testOutOfRangeLabels(self):
+ """An error is raised when labels are not in [0, num_classes)."""
+ with self.test_session():
+ logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
+ [0.5, 1.8, -1.0]])
+ labels = constant_op.constant([1, 0, 4])
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ with self.assertRaises(errors.InvalidArgumentError):
+ loss.eval()
+
+ def testZeroLossInt32Labels(self):
+ """Loss is 0 if true class logits sufficiently higher than other classes."""
+ with self.test_session():
+ logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
+ [0.5, 1.8, -1.0]])
+ labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testZeroLossInt64Labels(self):
+ """Loss is 0 if true class logits sufficiently higher than other classes."""
+ with self.test_session():
+ logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
+ [-0.5, 0.8, -1.0]])
+ labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testCorrectPredictionsSomeClassesInsideMargin(self):
+ """Loss is > 0 even if true class logits are higher than other classes."""
+ with self.test_session():
+ logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
+ [1.5, 1.8, -1.0]])
+ labels = constant_op.constant([0, 2, 1])
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ # The first and third samples incur some loss (0.6 and 0.7 respectively).
+ self.assertAlmostEqual(loss.eval(), 0.4333, 3)
+
+ def testIncorrectPredictions(self):
+ """Loss is >0 when an incorrect class has higher logits than true class."""
+ with self.test_session():
+ logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
+ [0.5, -1.8, 2.0]])
+ labels = constant_op.constant([1, 0, 2])
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ # The first example incurs a high loss (3.2) since the logits of an
+ # incorrect class (0) are higher than the logits of the ground truth
+ # class. The second example also incurs a (smaller) loss (0.4).
+ self.assertAlmostEqual(loss.eval(), 1.2, 3)
+
+ def testIncorrectPredictionsColumnLabels(self):
+ """Same as above but labels is a rank-2 tensor."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0]])
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits)
+ # The first example incurs a high loss (3.0) since the logits of an
+ # incorrect class (0) are higher than the logits of the ground truth
+ # class. The second example also incurs a (smaller) loss (0.3).
+ self.assertAlmostEqual(loss.eval(), 1.1, 3)
+
+ def testIncorrectPredictionsZeroWeights(self):
+ """Loss is 0 when all weights are missing even if predictions are wrong."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0]])
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ weights = constant_op.constant([0.0, 0.0, 0.0], shape=(3, 1))
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+ # No overall loss since all weights are 0.
+ self.assertAlmostEqual(loss.eval(), 0.0, 3)
+
+ def testNonZeroLossWithPythonScalarWeights(self):
+ """Weighted loss is correctly computed when weights is a python scalar."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0]])
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ weights = 10.0
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+ self.assertAlmostEqual(loss.eval(), 11.0, 3)
+
+ def testNonZeroLossWithScalarTensorWeights(self):
+ """Weighted loss is correctly computed when weights is a rank-0 tensor."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0]])
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ weights = constant_op.constant(5.0)
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+ self.assertAlmostEqual(loss.eval(), 5.5, 3)
+
+ def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
+ """Weighted loss is correctly computed when weights is a rank-0 tensor."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0]])
+ labels = constant_op.constant([1, 0, 2], shape=(3, 1))
+ weights = constant_op.constant([1.0, 0.5, 2.0], shape=(3,))
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+ # The overall loss is 1/3 * (3.0*1.0 + 0.5*0.3 + 2.0*0.0) = 1.05.
+ self.assertAlmostEqual(loss.eval(), 1.05, 3)
+
+ def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
+ """Weighted loss is correctly computed when weights is a rank-0 tensor."""
+ with self.test_session():
+ logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
+ [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
+ labels = constant_op.constant([1, 0, 2, 1])
+ weights = constant_op.constant([[1.0], [0.0], [2.0], [4.0]])
+ loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
+ # The overall loss is 1/3 * (3.0*1.0 + 0.0*0.3 + 2.0*0.0 + 4.0*0.8) = 6.2/3.
+ self.assertAlmostEqual(loss.eval(), 2.06666, 3)
+
+
+if __name__ == '__main__':
+ test.main()