diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-02-08 09:25:09 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-08 09:50:05 -0800 |
commit | 639b4e71f532761a4840b1cdbaea55ad0917c75b (patch) | |
tree | 5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/contrib/sparsemax | |
parent | 15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff) |
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/contrib/sparsemax')
6 files changed, 715 insertions, 0 deletions
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD new file mode 100644 index 0000000000..bd59c626f2 --- /dev/null +++ b/tensorflow/contrib/sparsemax/BUILD @@ -0,0 +1,76 @@ +# Description: +# Contains ops to train linear models on top of TensorFlow. +# APIs here are meant to evolve over time. + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +package(default_visibility = ["//visibility:public"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_py_test", +) +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_kernel_tests_linkstatic", +) + +py_library( + name = "sparsemax_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + ], +) + +cuda_py_tests( + name = "sparsemax_test", + size = "small", + srcs = ["python/kernel_tests/sparsemax_test.py"], + additional_deps = [ + ":sparsemax_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "sparsemax_loss_test", + size = "medium", + srcs = ["python/kernel_tests/sparsemax_loss_test.py"], + additional_deps = [ + ":sparsemax_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py new file mode 100644 index 0000000000..0be4988dbf --- /dev/null +++ b/tensorflow/contrib/sparsemax/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module that implements sparsemax and sparsemax loss, see [1]. + +[1] https://arxiv.org/abs/1602.02068 + +## Sparsemax + +@@sparsemax +@@sparsemax_loss +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax +from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \ + import sparsemax_loss diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py new file mode 100644 index 0000000000..89dbcd96f8 --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py @@ -0,0 +1,224 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SparsemaxLossOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + +test_obs = 10 + + +class SparsemaxLossTest(test.TestCase): + + def _np_sparsemax(self, z): + z = z - np.mean(z, axis=1)[:, np.newaxis] + + # sort z + z_sorted = np.sort(z, axis=1)[:, ::-1] + + # calculate k(z) + z_cumsum = np.cumsum(z_sorted, axis=1) + k = np.arange(1, z.shape[1] + 1) + z_check = 1 + k * z_sorted > z_cumsum + # use argmax to get the index by row as .nonzero() doesn't + # take an axis argument. np.argmax return the first index, but the last + # index is required here, use np.flip to get the last index and + # `z.shape[axis]` to compensate for np.flip afterwards. + k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1) + + # calculate tau(z) + tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1] + tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1) + + # calculate p + return np.maximum(0, z - tau_z) + + def _np_sparsemax_loss(self, z, q): + z = z - np.mean(z, axis=1)[:, np.newaxis] + + # Calculate q^T * z + z_k = np.sum(q * z, axis=1) + + # calculate sum over S(z) + p = self._np_sparsemax(z) + s = p > 0 + # z_i^2 - tau(z)^2 = p_i (2 * z_i - p_i) for i \in S(z) + S_sum = np.sum(s * p * (2 * z - p), axis=1) + + # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q) + q_norm = np.sum(q, axis=1) + + return -z_k + 0.5 * S_sum + 0.5 * q_norm + + def _np_sparsemax_loss_grad(self, z, q): + # chain rule + grad = 1 + + return grad * (-q + self._np_sparsemax(z)) + + def _tf_sparsemax(self, z, dtype, use_gpu): + with self.test_session(use_gpu=use_gpu): + tf_sparsemax_op = sparsemax(z.astype(dtype)) + tf_sparsemax_out = tf_sparsemax_op.eval() + + return tf_sparsemax_op, tf_sparsemax_out + + def _tf_sparsemax_loss(self, z, q, dtype, use_gpu): + z = z.astype(dtype) + q = q.astype(dtype) + + with self.test_session(use_gpu=use_gpu): + tf_sparsemax_op = sparsemax(z) + tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q) + tf_loss_out = tf_loss_op.eval() + + return tf_loss_op, tf_loss_out + + def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu): + """check sparsemax-loss kernel against numpy""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + q = np.zeros((test_obs, 10)) + q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1 + + tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu) + np_loss = self._np_sparsemax_loss(z, q).astype(dtype) + + self.assertAllCloseAccordingToType(np_loss, tf_loss_out, + half_atol=1e-2, half_rtol=5e-3) + self.assertShapeEqual(np_loss, tf_loss_op) + + def _test_constant_add(self, dtype, random, use_gpu): + """check sparsemax-loss proposition 3""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + c = random.uniform(low=-3, high=3, size=(test_obs, 1)) + q = np.zeros((test_obs, 10)) + q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1 + + _, tf_loss_zpc = self._tf_sparsemax_loss( + z + c, q, dtype, use_gpu + ) + + _, tf_loss_z = self._tf_sparsemax_loss( + z, q, dtype, use_gpu + ) + + self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z, + float_atol=5e-6, float_rtol=5e-6, + half_atol=1e-2, half_rtol=1e-2) + + def _test_sparsemax_loss_positive(self, dtype, random, use_gpu): + """check sparsemax-loss proposition 4""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + q = np.zeros((test_obs, 10)) + q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1 + + tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu) + + self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out) + self.assertShapeEqual(np.zeros(test_obs), tf_loss_op) + + def _test_sparsemax_loss_zero(self, dtype, random, use_gpu): + """check sparsemax-loss proposition 5""" + # construct z and q, such that z_k >= 1 + max_{j!=k} z_k holds for + # delta_0 = 1. + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + z[:, 0] = np.max(z, axis=1) + 1.05 + + q = np.zeros((test_obs, 10)) + q[:, 0] = 1 + + tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu) + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) + + self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out) + self.assertShapeEqual(np.zeros(test_obs), tf_loss_op) + + self.assertAllCloseAccordingToType(q, tf_sparsemax_out) + self.assertShapeEqual(q, tf_sparsemax_op) + + def _test_gradient_against_estimate(self, dtype, random, use_gpu): + """check sparsemax-loss Rop, aginst estimated-loss Rop""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) + q = np.zeros((test_obs, 10)).astype(dtype) + q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1 + + logits = array_ops.placeholder(dtype, name='z') + sparsemax_op = sparsemax(logits) + loss_op = sparsemax_loss(logits, sparsemax_op, q) + + with self.test_session(use_gpu=use_gpu): + err = gradient_checker.compute_gradient_error( + logits, z.shape, + loss_op, (test_obs, ), + x_init_value=z, delta=1e-9 + ) + + self.assertLess(err, 1e-4) + + def _test_gradient_against_numpy(self, dtype, random, use_gpu): + """check sparsemax-loss Rop, aginst numpy Rop""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + q = np.zeros((test_obs, 10)) + q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1 + + logits = constant_op.constant(z.astype(dtype), name='z') + sparsemax_op = sparsemax(logits) + loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype)) + loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0] + + with self.test_session(use_gpu=use_gpu): + tf_grad = loss_grad_op.eval() + np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype) + + self.assertAllCloseAccordingToType(np_grad, tf_grad, + half_atol=1e-2, half_rtol=5e-3) + self.assertShapeEqual(np_grad, loss_grad_op) + + def _test_dtype(self, dtype): + random = np.random.RandomState(1) + + self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False) + + self._test_constant_add(dtype, random, use_gpu=False) + + self._test_sparsemax_loss_positive(dtype, random, use_gpu=False) + + self._test_sparsemax_loss_zero(dtype, random, use_gpu=False) + + # sparsemax is not a smooth function so gradient estimation is only + # possibol for float64. + if dtype == 'float64': + self._test_gradient_against_estimate(dtype, random, use_gpu=False) + + self._test_gradient_against_numpy(dtype, random, use_gpu=False) + + def testFloat(self): + self._test_dtype('float32') + + def testDouble(self): + self._test_dtype('float64') + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py new file mode 100644 index 0000000000..eafac1b9ae --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py @@ -0,0 +1,252 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SparsemaxOp.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.sparsemax import sparsemax +from tensorflow.python.ops import gradient_checker +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + +test_obs = 10 + + +class SparsemaxTest(test.TestCase): + + def _np_sparsemax(self, z): + z = z - np.mean(z, axis=1)[:, np.newaxis] + + # sort z + z_sorted = np.sort(z, axis=1)[:, ::-1] + + # calculate k(z) + z_cumsum = np.cumsum(z_sorted, axis=1) + k = np.arange(1, z.shape[1] + 1) + z_check = 1 + k * z_sorted > z_cumsum + # use argmax to get the index by row as .nonzero() doesn't + # take an axis argument. np.argmax return the first index, but the last + # index is required here, use np.flip to get the last index and + # `z.shape[axis]` to compensate for np.flip afterwards. + k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1) + + # calculate tau(z) + tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1] + tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1) + + # calculate p + return np.maximum(0, z - tau_z) + + def _np_sparsemax_grad(self, z): + # chain rule + grad = np.ones_like(z) + + # Construct S(z) + probability = self._np_sparsemax(z) + support = probability > 0 + + # Calculate \hat{v}, which will be a vector (scalar for each z) + v_hat = np.sum(grad * support, axis=1) / np.sum(support, axis=1) + + # Calculates J(z) * v + return support * (grad - v_hat[:, np.newaxis]) + + def _tf_sparsemax(self, z, dtype, use_gpu): + with self.test_session(use_gpu=use_gpu): + tf_sparsemax_op = sparsemax(z.astype(dtype)) + tf_sparsemax_out = tf_sparsemax_op.eval() + + return tf_sparsemax_op, tf_sparsemax_out + + def _test_sparsemax_against_numpy(self, dtype, random, use_gpu): + """check sparsemax kernel against numpy""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) + p_sparemax = self._np_sparsemax(z).astype(dtype) + + self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out, + half_atol=5e-3) + self.assertShapeEqual(p_sparemax, tf_sparsemax_op) + + def _test_sparsemax_of_zero(self, dtype, random, use_gpu): + """check sparsemax proposition 1, part 1""" + z = np.zeros((1, 10)) + + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) + p_sparemax = np.ones_like(z, dtype=dtype) / z.size + + self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out) + self.assertShapeEqual(p_sparemax, tf_sparsemax_op) + + def _test_sparsemax_of_inf(self, dtype, random, use_gpu): + """check sparsemax proposition 1, part 2""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + + # assume |A(z)| = 1, as z is continues random + z_sort_arg = np.argsort(z, axis=1)[:, ::-1] + z_sort = np.sort(z, axis=-1)[:, ::-1] + gamma_z = z_sort[:, 0] - z_sort[:, 1] + epsilon = (0.99 * gamma_z * 1).reshape(-1, 1) + + # construct the expected 1_A(z) array + p_expected = np.zeros((test_obs, 10), dtype=dtype) + p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1 + + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax( + (1 / epsilon) * z, dtype, use_gpu + ) + + self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out) + self.assertShapeEqual(p_expected, tf_sparsemax_op) + + def _test_constant_add(self, dtype, random, use_gpu): + """check sparsemax proposition 2""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) + c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype) + + _, tf_sparsemax_zpc = self._tf_sparsemax( + z + c, dtype, use_gpu + ) + + _, tf_sparsemax_z = self._tf_sparsemax( + z, dtype, use_gpu + ) + + self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z, + half_atol=5e-3) + + def _test_permutation(self, dtype, random, use_gpu): + """check sparsemax proposition 3""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + _, p = self._tf_sparsemax(z, dtype, use_gpu) + + for i in range(test_obs): + per = random.permutation(10) + + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax( + z[i, per].reshape(1, -1), dtype, use_gpu + ) + p_expected = p[i, per].reshape(1, -1) + + self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out, + half_atol=5e-3) + self.assertShapeEqual(p_expected, tf_sparsemax_op) + + def _test_diffrence(self, dtype, random, use_gpu): + """check sparsemax proposition 4""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)) + _, p = self._tf_sparsemax(z, dtype, use_gpu) + + etol = {'float16': 1e-2, 'float32': 1e-6, 'float64': 1e-9}[dtype] + + for val in range(0, test_obs): + for i in range(0, 10): + for j in range(0, 10): + # check condition, the obesite pair will be checked anyway + if z[val, i] > z[val, j]: + continue + + self.assertTrue( + 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol, + "0 <= %.10f <= %.10f" % ( + p[val, j] - p[val, i], z[val, j] - z[val, i] + etol + ) + ) + + def _test_two_dimentional(self, dtype, random, use_gpu): + """check two dimentation sparsemax case""" + t = np.linspace(-2, 2, test_obs, dtype=dtype) + z = np.vstack([ + t, np.zeros(test_obs, dtype=dtype) + ]).T + + tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu) + + p0_expected = np.select([t < -1, t <= 1, t > 1], [0, (t + 1) / 2, 1]) + + self.assertAllCloseAccordingToType(p0_expected, tf_sparsemax_out[:, 0]) + self.assertAllCloseAccordingToType(1 - p0_expected, tf_sparsemax_out[:, 1]) + self.assertShapeEqual(z, tf_sparsemax_op) + + def _test_gradient_against_estimate(self, dtype, random, use_gpu): + """check sparsemax Rop, aginst estimated Rop""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) + + logits = array_ops.placeholder(dtype, name='z') + sparsemax_op = sparsemax(logits) + + with self.test_session(use_gpu=use_gpu): + err = gradient_checker.compute_gradient_error( + logits, z.shape, + sparsemax_op, z.shape, + x_init_value=z, delta=1e-9 + ) + + self.assertLess(err, 1e-4) + + def _test_gradient_against_numpy(self, dtype, random, use_gpu): + """check sparsemax Rop, aginst numpy Rop""" + z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype) + + logits = constant_op.constant(z, name='z') + sparsemax_op = sparsemax(logits) + sparsemax_grad_op = gradients_impl.gradients(sparsemax_op, [logits])[0] + + with self.test_session(use_gpu=use_gpu): + tf_grad = sparsemax_grad_op.eval() + np_grad = self._np_sparsemax_grad(z) + + self.assertAllCloseAccordingToType(np_grad, tf_grad) + self.assertShapeEqual(np_grad, sparsemax_grad_op) + + def _test_dtype(self, dtype): + random = np.random.RandomState(1) + + self._test_sparsemax_against_numpy(dtype, random, use_gpu=False) + + self._test_sparsemax_of_zero(dtype, random, use_gpu=False) + + self._test_sparsemax_of_inf(dtype, random, use_gpu=False) + + self._test_constant_add(dtype, random, use_gpu=False) + + self._test_permutation(dtype, random, use_gpu=False) + + self._test_diffrence(dtype, random, use_gpu=False) + + self._test_two_dimentional(dtype, random, use_gpu=False) + + # sparsemax is not a smooth function so gradient estimation is only + # possibol for float64. + if dtype == 'float64': + self._test_gradient_against_estimate(dtype, random, use_gpu=False) + + self._test_gradient_against_numpy(dtype, random, use_gpu=False) + + def testFloat(self): + self._test_dtype('float32') + + def testDouble(self): + self._test_dtype('float64') + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py new file mode 100644 index 0000000000..6e1cd75f22 --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -0,0 +1,74 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sparsemax op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader +from tensorflow.python.framework import ops, dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn + + +def sparsemax(logits, name=None): + """Computes sparsemax activations [1]. + + For each batch `i` and class `j` we have + sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0) + + [1]: https://arxiv.org/abs/1602.02068 + + Args: + logits: A `Tensor`. Must be one of the following types: `half`, `float32`, + `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `logits`. + """ + + with ops.name_scope(name, "sparsemax", [logits]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + obs = array_ops.shape(logits)[0] + dims = array_ops.shape(logits)[1] + + z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + + # sort z + z_sorted, _ = nn.top_k(z, k=dims) + + # calculate k(z) + z_cumsum = math_ops.cumsum(z_sorted, axis=1) + k = math_ops.range( + 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype + ) + z_check = 1 + k * z_sorted > z_cumsum + # because the z_check vector is always [1,1,...1,0,0,...0] finding the + # (index + 1) of the last `1` is the same as just summing the number of 1. + k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1) + + # calculate tau(z) + indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1) + tau_sum = array_ops.gather_nd(z_cumsum, indices) + tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype) + + # calculate p + return math_ops.maximum( + math_ops.cast(0, logits.dtype), + z - tau_z[:, array_ops.newaxis] + ) diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py new file mode 100644 index 0000000000..1f5e8c37e3 --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sparsemax Loss op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def sparsemax_loss(logits, sparsemax, labels, name=None): + """Computes sparsemax loss function [1]. + + [1]: https://arxiv.org/abs/1602.02068 + + Args: + logits: A `Tensor`. Must be one of the following types: `half`, `float32`, + `float64`. + sparsemax: A `Tensor`. Must have the same type as `logits`. + labels: A `Tensor`. Must have the same type as `logits`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `logits`. + """ + + with ops.name_scope(name, "sparsemax_loss", + [logits, sparsemax, labels]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax") + labels = ops.convert_to_tensor(labels, name="labels") + + shifted_logits = logits - \ + math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + + # sum over support + support = math_ops.cast(sparsemax > 0, sparsemax.dtype) + sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax) + + # - z_k + ||q||^2 + q_part = labels * (0.5 * labels - shifted_logits) + + return math_ops.reduce_sum(sum_s + q_part, axis=1) |