path: root/tensorflow/contrib/sparsemax
author    Benoit Steiner <bsteiner@google.com>  2017-02-08 09:25:09 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-02-08 09:50:05 -0800
commit    639b4e71f532761a4840b1cdbaea55ad0917c75b (patch)
tree      5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/contrib/sparsemax
parent    15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff)
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/contrib/sparsemax')
-rw-r--r--  tensorflow/contrib/sparsemax/BUILD                                       |  76
-rw-r--r--  tensorflow/contrib/sparsemax/__init__.py                                 |  30
-rw-r--r--  tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py  | 224
-rw-r--r--  tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py       | 252
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax.py                     |  74
-rw-r--r--  tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py                |  59
6 files changed, 715 insertions, 0 deletions
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
new file mode 100644
index 0000000000..bd59c626f2
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -0,0 +1,76 @@
+# Description:
+#   Contains ops that implement the sparsemax transformation and its loss
+#   (https://arxiv.org/abs/1602.02068). APIs here are meant to evolve over time.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_py_test",
+)
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_kernel_tests_linkstatic",
+)
+
+py_library(
+ name = "sparsemax_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
+cuda_py_tests(
+ name = "sparsemax_test",
+ size = "small",
+ srcs = ["python/kernel_tests/sparsemax_test.py"],
+ additional_deps = [
+ ":sparsemax_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
+ name = "sparsemax_loss_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/sparsemax_loss_test.py"],
+ additional_deps = [
+ ":sparsemax_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py
new file mode 100644
index 0000000000..0be4988dbf
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module that implements sparsemax and sparsemax loss, see [1].
+
+[1] https://arxiv.org/abs/1602.02068
+
+## Sparsemax
+
+@@sparsemax
+@@sparsemax_loss
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax
+from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss import (
+ sparsemax_loss)
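
For orientation, a minimal usage sketch of the two exported symbols follows. It is not part of the change itself and assumes a TF 1.x session environment in which this package is importable; the shapes and label indices are illustrative.

import numpy as np
import tensorflow as tf

from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss

# Four observations over ten classes; labels are one-hot rows.
logits = tf.constant(np.random.randn(4, 10).astype(np.float32))
labels = tf.one_hot([1, 3, 5, 7], depth=10)

probs = sparsemax(logits)
loss = sparsemax_loss(logits, probs, labels)

with tf.Session() as sess:
    p, l = sess.run([probs, loss])

print(p.sum(axis=1))  # every row sums to 1, and several entries are exactly 0
print(l)              # one non-negative loss value per observation
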
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
new file mode 100644
index 0000000000..89dbcd96f8
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -0,0 +1,224 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SparsemaxLossOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+test_obs = 10
+
+
+class SparsemaxLossTest(test.TestCase):
+
+ def _np_sparsemax(self, z):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # sort z
+ z_sorted = np.sort(z, axis=1)[:, ::-1]
+
+ # calculate k(z)
+ z_cumsum = np.cumsum(z_sorted, axis=1)
+ k = np.arange(1, z.shape[1] + 1)
+ z_check = 1 + k * z_sorted > z_cumsum
+ # use argmax to get the index by row as .nonzero() doesn't
+ # take an axis argument. np.argmax returns the first index, but the
+ # last index is needed here, so reverse each row and use
+ # `z.shape[axis]` to compensate for the reversal afterwards.
+ k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)
+
+ # calculate tau(z)
+ tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
+ tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)
+
+ # calculate p
+ return np.maximum(0, z - tau_z)
+
+ def _np_sparsemax_loss(self, z, q):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # Calculate q^T * z
+ z_k = np.sum(q * z, axis=1)
+
+ # calculate sum over S(z)
+ p = self._np_sparsemax(z)
+ s = p > 0
+ # z_i^2 - tau(z)^2 = p_i (2 * z_i - p_i) for i \in S(z)
+ S_sum = np.sum(s * p * (2 * z - p), axis=1)
+
+ # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q)
+ q_norm = np.sum(q, axis=1)
+
+ return -z_k + 0.5 * S_sum + 0.5 * q_norm
+
+ def _np_sparsemax_loss_grad(self, z, q):
+ # chain rule
+ grad = 1
+
+ return grad * (-q + self._np_sparsemax(z))
+
+ def _tf_sparsemax(self, z, dtype, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z.astype(dtype))
+ tf_sparsemax_out = tf_sparsemax_op.eval()
+
+ return tf_sparsemax_op, tf_sparsemax_out
+
+ def _tf_sparsemax_loss(self, z, q, dtype, use_gpu):
+ z = z.astype(dtype)
+ q = q.astype(dtype)
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z)
+ tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q)
+ tf_loss_out = tf_loss_op.eval()
+
+ return tf_loss_op, tf_loss_out
+
+ def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax-loss kernel against numpy"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+ np_loss = self._np_sparsemax_loss(z, q).astype(dtype)
+
+ self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
+ half_atol=1e-2, half_rtol=5e-3)
+ self.assertShapeEqual(np_loss, tf_loss_op)
+
+ def _test_constant_add(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 3"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ c = random.uniform(low=-3, high=3, size=(test_obs, 1))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ _, tf_loss_zpc = self._tf_sparsemax_loss(
+ z + c, q, dtype, use_gpu
+ )
+
+ _, tf_loss_z = self._tf_sparsemax_loss(
+ z, q, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
+ float_atol=5e-6, float_rtol=5e-6,
+ half_atol=1e-2, half_rtol=1e-2)
+
+ def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 4"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+
+ self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out)
+ self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)
+
+ def _test_sparsemax_loss_zero(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 5"""
+ # construct z and q, such that z_k >= 1 + max_{j!=k} z_j holds for
+ # delta_0 = 1.
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ z[:, 0] = np.max(z, axis=1) + 1.05
+
+ q = np.zeros((test_obs, 10))
+ q[:, 0] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+
+ self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out)
+ self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)
+
+ self.assertAllCloseAccordingToType(q, tf_sparsemax_out)
+ self.assertShapeEqual(q, tf_sparsemax_op)
+
+ def _test_gradient_against_estimate(self, dtype, random, use_gpu):
+ """check sparsemax-loss Rop, aginst estimated-loss Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+ q = np.zeros((test_obs, 10)).astype(dtype)
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ logits = array_ops.placeholder(dtype, name='z')
+ sparsemax_op = sparsemax(logits)
+ loss_op = sparsemax_loss(logits, sparsemax_op, q)
+
+ with self.test_session(use_gpu=use_gpu):
+ err = gradient_checker.compute_gradient_error(
+ logits, z.shape,
+ loss_op, (test_obs, ),
+ x_init_value=z, delta=1e-9
+ )
+
+ self.assertLess(err, 1e-4)
+
+ def _test_gradient_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax-loss Rop, aginst numpy Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ logits = constant_op.constant(z.astype(dtype), name='z')
+ sparsemax_op = sparsemax(logits)
+ loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype))
+ loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0]
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_grad = loss_grad_op.eval()
+ np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)
+
+ self.assertAllCloseAccordingToType(np_grad, tf_grad,
+ half_atol=1e-2, half_rtol=5e-3)
+ self.assertShapeEqual(np_grad, loss_grad_op)
+
+ def _test_dtype(self, dtype):
+ random = np.random.RandomState(1)
+
+ self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+
+ self._test_constant_add(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_zero(dtype, random, use_gpu=False)
+
+ # sparsemax is not a smooth function so gradient estimation is only
+ # possible for float64.
+ if dtype == 'float64':
+ self._test_gradient_against_estimate(dtype, random, use_gpu=False)
+
+ self._test_gradient_against_numpy(dtype, random, use_gpu=False)
+
+ def testFloat(self):
+ self._test_dtype('float32')
+
+ def testDouble(self):
+ self._test_dtype('float64')
+
+if __name__ == "__main__":
+ test.main()
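
The closed-form gradient that _np_sparsemax_loss_grad encodes, grad = -q + sparsemax(z), can be cross-checked with central finite differences of the loss. The standalone numpy sketch below mirrors the reference implementations in this test (the helper names are local to the sketch, not part of the commit); agreement is tight away from support boundaries, where the loss is only piecewise quadratic.

import numpy as np

def np_sparsemax(z):
    # reference projection, as in _np_sparsemax above
    z = z - z.mean(axis=1, keepdims=True)
    z_sorted = np.sort(z, axis=1)[:, ::-1]
    z_cumsum = np.cumsum(z_sorted, axis=1)
    k = np.arange(1, z.shape[1] + 1)
    k_z = z.shape[1] - np.argmax((1 + k * z_sorted > z_cumsum)[:, ::-1], axis=1)
    tau_z = (z_cumsum[np.arange(z.shape[0]), k_z - 1] - 1) / k_z
    return np.maximum(0, z - tau_z[:, np.newaxis])

def np_sparsemax_loss(z, q):
    # reference loss, as in _np_sparsemax_loss above
    z = z - z.mean(axis=1, keepdims=True)
    p = np_sparsemax(z)
    s_sum = np.sum((p > 0) * p * (2 * z - p), axis=1)
    return -np.sum(q * z, axis=1) + 0.5 * s_sum + 0.5 * np.sum(q, axis=1)

rng = np.random.RandomState(0)
z = rng.uniform(low=-3, high=3, size=(1, 5))
q = np.zeros((1, 5))
q[0, 2] = 1.0

analytic = -q + np_sparsemax(z)  # closed form

eps = 1e-6                       # central finite differences
numeric = np.zeros_like(z)
for i in range(z.shape[1]):
    e = np.zeros_like(z)
    e[0, i] = eps
    diff = np_sparsemax_loss(z + e, q) - np_sparsemax_loss(z - e, q)
    numeric[0, i] = diff[0] / (2 * eps)

# should be tiny away from support boundaries
print(np.max(np.abs(analytic - numeric)))
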
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
new file mode 100644
index 0000000000..eafac1b9ae
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -0,0 +1,252 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SparsemaxOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.sparsemax import sparsemax
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+test_obs = 10
+
+
+class SparsemaxTest(test.TestCase):
+
+ def _np_sparsemax(self, z):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # sort z
+ z_sorted = np.sort(z, axis=1)[:, ::-1]
+
+ # calculate k(z)
+ z_cumsum = np.cumsum(z_sorted, axis=1)
+ k = np.arange(1, z.shape[1] + 1)
+ z_check = 1 + k * z_sorted > z_cumsum
+ # use argmax to get the index by row as .nonzero() doesn't
+ # take an axis argument. np.argmax returns the first index, but the
+ # last index is needed here, so reverse each row and use
+ # `z.shape[axis]` to compensate for the reversal afterwards.
+ k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)
+
+ # calculate tau(z)
+ tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
+ tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)
+
+ # calculate p
+ return np.maximum(0, z - tau_z)
+
+ def _np_sparsemax_grad(self, z):
+ # chain rule
+ grad = np.ones_like(z)
+
+ # Construct S(z)
+ probability = self._np_sparsemax(z)
+ support = probability > 0
+
+ # Calculate \hat{v}, which will be a vector (scalar for each z)
+ v_hat = np.sum(grad * support, axis=1) / np.sum(support, axis=1)
+
+ # Calculates J(z) * v
+ return support * (grad - v_hat[:, np.newaxis])
+
+ def _tf_sparsemax(self, z, dtype, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z.astype(dtype))
+ tf_sparsemax_out = tf_sparsemax_op.eval()
+
+ return tf_sparsemax_op, tf_sparsemax_out
+
+ def _test_sparsemax_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax kernel against numpy"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+ p_sparsemax = self._np_sparsemax(z).astype(dtype)
+
+ self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out,
+ half_atol=5e-3)
+ self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)
+
+ def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
+ """check sparsemax proposition 1, part 1"""
+ z = np.zeros((1, 10))
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+ p_sparsemax = np.ones_like(z, dtype=dtype) / z.size
+
+ self.assertAllCloseAccordingToType(p_sparsemax, tf_sparsemax_out)
+ self.assertShapeEqual(p_sparsemax, tf_sparsemax_op)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax proposition 1, part 2"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+
+ # assume |A(z)| = 1, as z is drawn from a continuous distribution
+ z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
+ z_sort = np.sort(z, axis=-1)[:, ::-1]
+ gamma_z = z_sort[:, 0] - z_sort[:, 1]
+ epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)
+
+ # construct the expected 1_A(z) array
+ p_expected = np.zeros((test_obs, 10), dtype=dtype)
+ p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
+ (1 / epsilon) * z, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
+ self.assertShapeEqual(p_expected, tf_sparsemax_op)
+
+ def _test_constant_add(self, dtype, random, use_gpu):
+ """check sparsemax proposition 2"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+ c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
+
+ _, tf_sparsemax_zpc = self._tf_sparsemax(
+ z + c, dtype, use_gpu
+ )
+
+ _, tf_sparsemax_z = self._tf_sparsemax(
+ z, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
+ half_atol=5e-3)
+
+ def _test_permutation(self, dtype, random, use_gpu):
+ """check sparsemax proposition 3"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ _, p = self._tf_sparsemax(z, dtype, use_gpu)
+
+ for i in range(test_obs):
+ per = random.permutation(10)
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
+ z[i, per].reshape(1, -1), dtype, use_gpu
+ )
+ p_expected = p[i, per].reshape(1, -1)
+
+ self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
+ half_atol=5e-3)
+ self.assertShapeEqual(p_expected, tf_sparsemax_op)
+
+ def _test_difference(self, dtype, random, use_gpu):
+ """check sparsemax proposition 4"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ _, p = self._tf_sparsemax(z, dtype, use_gpu)
+
+ etol = {'float16': 1e-2, 'float32': 1e-6, 'float64': 1e-9}[dtype]
+
+ for val in range(0, test_obs):
+ for i in range(0, 10):
+ for j in range(0, 10):
+ # check the condition; the opposite pair will be checked anyway
+ if z[val, i] > z[val, j]:
+ continue
+
+ self.assertTrue(
+ 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
+ "0 <= %.10f <= %.10f" % (
+ p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
+ )
+ )
+
+ def _test_two_dimensional(self, dtype, random, use_gpu):
+ """check the two-dimensional sparsemax case"""
+ t = np.linspace(-2, 2, test_obs, dtype=dtype)
+ z = np.vstack([
+ t, np.zeros(test_obs, dtype=dtype)
+ ]).T
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+
+ p0_expected = np.select([t < -1, t <= 1, t > 1], [0, (t + 1) / 2, 1])
+
+ self.assertAllCloseAccordingToType(p0_expected, tf_sparsemax_out[:, 0])
+ self.assertAllCloseAccordingToType(1 - p0_expected, tf_sparsemax_out[:, 1])
+ self.assertShapeEqual(z, tf_sparsemax_op)
+
+ def _test_gradient_against_estimate(self, dtype, random, use_gpu):
+ """check sparsemax Rop, aginst estimated Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+
+ logits = array_ops.placeholder(dtype, name='z')
+ sparsemax_op = sparsemax(logits)
+
+ with self.test_session(use_gpu=use_gpu):
+ err = gradient_checker.compute_gradient_error(
+ logits, z.shape,
+ sparsemax_op, z.shape,
+ x_init_value=z, delta=1e-9
+ )
+
+ self.assertLess(err, 1e-4)
+
+ def _test_gradient_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax Rop, aginst numpy Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+
+ logits = constant_op.constant(z, name='z')
+ sparsemax_op = sparsemax(logits)
+ sparsemax_grad_op = gradients_impl.gradients(sparsemax_op, [logits])[0]
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_grad = sparsemax_grad_op.eval()
+ np_grad = self._np_sparsemax_grad(z)
+
+ self.assertAllCloseAccordingToType(np_grad, tf_grad)
+ self.assertShapeEqual(np_grad, sparsemax_grad_op)
+
+ def _test_dtype(self, dtype):
+ random = np.random.RandomState(1)
+
+ self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+
+ self._test_constant_add(dtype, random, use_gpu=False)
+
+ self._test_permutation(dtype, random, use_gpu=False)
+
+ self._test_difference(dtype, random, use_gpu=False)
+
+ self._test_two_dimensional(dtype, random, use_gpu=False)
+
+ # sparsemax is not a smooth function so gradient estimation is only
+ # possible for float64.
+ if dtype == 'float64':
+ self._test_gradient_against_estimate(dtype, random, use_gpu=False)
+
+ self._test_gradient_against_numpy(dtype, random, use_gpu=False)
+
+ def testFloat(self):
+ self._test_dtype('float32')
+
+ def testDouble(self):
+ self._test_dtype('float64')
+
+if __name__ == "__main__":
+ test.main()
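
To make the k(z) and tau(z) bookkeeping in _np_sparsemax concrete, here is a hand-checkable walk-through on a single, already mean-centered row (a sketch for illustration, not part of the commit):

import numpy as np

z = np.array([[0.8, 0.1, -0.9]])               # already mean-centered

z_sorted = np.sort(z, axis=1)[:, ::-1]         # [[ 0.8,  0.1, -0.9]]
z_cumsum = np.cumsum(z_sorted, axis=1)         # [[ 0.8,  0.9,  0. ]]
k = np.arange(1, z.shape[1] + 1)
z_check = 1 + k * z_sorted > z_cumsum          # [[ True,  True, False]]
k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)  # [2]: support size
tau_z = (z_cumsum[[0], k_z - 1] - 1) / k_z     # [-0.05]
p = np.maximum(0, z - tau_z[:, np.newaxis])    # [[0.85, 0.15, 0.  ]]

print(p, p.sum())  # a proper distribution with one exact zero
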
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
new file mode 100644
index 0000000000..6e1cd75f22
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -0,0 +1,74 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparsemax op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.framework import ops, dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+
+
+def sparsemax(logits, name=None):
+ """Computes sparsemax activations [1].
+
+ For each batch `i` and class `j` we have
+ sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)
+
+ [1]: https://arxiv.org/abs/1602.02068
+
+ Args:
+ logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
+ `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`.
+ """
+
+ with ops.name_scope(name, "sparsemax", [logits]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ obs = array_ops.shape(logits)[0]
+ dims = array_ops.shape(logits)[1]
+
+ z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+
+ # sort z
+ z_sorted, _ = nn.top_k(z, k=dims)
+
+ # calculate k(z)
+ z_cumsum = math_ops.cumsum(z_sorted, axis=1)
+ k = math_ops.range(
+ 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype
+ )
+ z_check = 1 + k * z_sorted > z_cumsum
+ # because the z_check vector is always [1,1,...1,0,0,...0], finding the
+ # (index + 1) of the last `1` is the same as just counting the ones.
+ k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
+
+ # calculate tau(z)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ tau_sum = array_ops.gather_nd(z_cumsum, indices)
+ tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
+
+ # calculate p
+ return math_ops.maximum(
+ math_ops.cast(0, logits.dtype),
+ z - tau_z[:, array_ops.newaxis]
+ )
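
A small behavioral contrast with softmax may help (a sketch assuming a TF 1.x graph session; the numbers are illustrative): softmax keeps every class strictly positive, while sparsemax drives classes outside the support set to exact zeros.

import tensorflow as tf

from tensorflow.contrib.sparsemax import sparsemax

logits = tf.constant([[1.5, 1.0, 0.2, -1.0]])

with tf.Session() as sess:
    soft = sess.run(tf.nn.softmax(logits))
    sparse = sess.run(sparsemax(logits))

print(soft)    # all four entries strictly positive
print(sparse)  # [[0.75 0.25 0.   0.  ]] -- support of size k(z) = 2
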
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
new file mode 100644
index 0000000000..1f5e8c37e3
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -0,0 +1,59 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparsemax Loss op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def sparsemax_loss(logits, sparsemax, labels, name=None):
+ """Computes sparsemax loss function [1].
+
+ [1]: https://arxiv.org/abs/1602.02068
+
+ Args:
+ logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
+ `float64`.
+ sparsemax: A `Tensor`. Must have the same type as `logits`.
+ labels: A `Tensor`. Must have the same type as `logits`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`.
+ """
+
+ with ops.name_scope(name, "sparsemax_loss",
+ [logits, sparsemax, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
+ labels = ops.convert_to_tensor(labels, name="labels")
+
+ shifted_logits = logits - math_ops.reduce_mean(
+ logits, axis=1)[:, array_ops.newaxis]
+
+ # sum over support
+ support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
+ sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+
+ # - z_k + ||q||^2
+ q_part = labels * (0.5 * labels - shifted_logits)
+
+ return math_ops.reduce_sum(sum_s + q_part, axis=1)
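
The support sum above is an algebraic rewrite of the loss from the paper, -q^T z + 0.5 * sum_{i in S(z)} (z_i^2 - tau(z)^2) + 0.5 * ||q||^2: on the support sparsemax_i = z_i - tau(z), so z_i^2 - tau(z)^2 = sparsemax_i * (2 * z_i - sparsemax_i). A numpy sketch checking the two forms on one mean-centered row (the concrete values, including tau = 0.325, are illustrative):

import numpy as np

z = np.array([[1.075, 0.575, -0.225, -1.425]])  # mean-centered logits
p = np.array([[0.75, 0.25, 0.0, 0.0]])          # sparsemax(z), with tau(z) = 0.325
q = np.array([[0.0, 1.0, 0.0, 0.0]])            # one-hot label
tau = 0.325

support = p > 0
# form computed by the op: support sum plus the label part
form_op = np.sum(support * p * (z - 0.5 * p) + q * (0.5 * q - z), axis=1)
# form as stated in the paper
form_paper = (-np.sum(q * z, axis=1)
              + 0.5 * np.sum(support * (z ** 2 - tau ** 2), axis=1)
              + 0.5 * np.sum(q ** 2, axis=1))

print(form_op, form_paper)  # both 0.5625; equal up to rounding
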