aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-02-08 09:25:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-08 09:50:05 -0800
commit639b4e71f532761a4840b1cdbaea55ad0917c75b (patch)
tree5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/contrib
parent15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff)
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt4
-rw-r--r--tensorflow/contrib/cmake/tf_cc_ops.cmake3
-rw-r--r--tensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm.py7
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops.py7
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_test.py10
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py3
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py20
-rw-r--r--tensorflow/contrib/learn/python/learn/models.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py2
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops.py1
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py52
-rw-r--r--tensorflow/contrib/nccl/BUILD2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py2
-rw-r--r--tensorflow/contrib/slim/README.md2
-rw-r--r--tensorflow/contrib/sparsemax/BUILD76
-rw-r--r--tensorflow/contrib/sparsemax/__init__.py30
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py224
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py252
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax.py74
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py59
24 files changed, 828 insertions, 12 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 680053ae18..d1d8b19d69 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -47,6 +47,7 @@ py_library(
"//tensorflow/contrib/slim",
"//tensorflow/contrib/slim:nets",
"//tensorflow/contrib/solvers:solvers_py",
+ "//tensorflow/contrib/sparsemax:sparsemax_py",
"//tensorflow/contrib/specs",
"//tensorflow/contrib/stat_summarizer:stat_summarizer_py",
"//tensorflow/contrib/tensor_forest:init_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9404b7a146..fede580f0f 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -49,6 +49,7 @@ from tensorflow.contrib import rnn
from tensorflow.contrib import seq2seq
from tensorflow.contrib import slim
from tensorflow.contrib import solvers
+from tensorflow.contrib import sparsemax
from tensorflow.contrib import stat_summarizer
from tensorflow.contrib import tensor_forest
from tensorflow.contrib import tensorboard
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index 64262fdce5..68929da5c9 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -170,7 +170,8 @@ if (tensorflow_ENABLE_GPU)
# add cudnn
include_directories(${CUDNN_HOME})
- set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDNN_HOME}/lib/x64/cudnn.lib)
+ set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES}
+ ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib)
# create cuda_config.h
FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h
@@ -179,6 +180,7 @@ if (tensorflow_ENABLE_GPU)
"#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n"
"#define TF_CUDA_VERSION \"64_80\"\n"
"#define TF_CUDNN_VERSION \"64_5\"\n"
+ "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n"
"#endif // CUDA_CUDA_CONFIG_H_\n"
)
diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake
index 6eaa2502be..bca700aca2 100644
--- a/tensorflow/contrib/cmake/tf_cc_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake
@@ -71,7 +71,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal}
DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir
)
-
+
list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h)
list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc)
list(APPEND tf_cc_ops_generated_files ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h)
@@ -79,6 +79,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names})
endforeach()
+
########################################################
# tf_cc_ops library
########################################################
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 7717cf7b71..9ab6f176c7 100644
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -372,6 +372,9 @@ add_python_module("tensorflow/contrib/slim/python/slim/nets")
add_python_module("tensorflow/contrib/solvers")
add_python_module("tensorflow/contrib/solvers/python")
add_python_module("tensorflow/contrib/solvers/python/ops")
+add_python_module("tensorflow/contrib/sparsemax")
+add_python_module("tensorflow/contrib/sparsemax/python")
+add_python_module("tensorflow/contrib/sparsemax/python/ops")
add_python_module("tensorflow/contrib/specs")
add_python_module("tensorflow/contrib/specs/python")
add_python_module("tensorflow/contrib/stat_summarizer")
diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py
index eddce45c88..72d01fbb2a 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm.py
@@ -102,7 +102,12 @@ class GMM(estimator.Estimator):
results = self.evaluate(input_fn=input_fn, batch_size=batch_size,
steps=steps)
return np.sum(results[GMM.SCORES])
-
+
+ def weights(self):
+ """Returns the cluster weights."""
+ return checkpoint_utils.load_variable(
+ self.model_dir, gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT)
+
def clusters(self):
"""Returns cluster centers."""
clusters = checkpoint_utils.load_variable(
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index e795c0aac7..fbf7afc125 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
@@ -92,6 +92,7 @@ def _init_clusters_random(data, num_clusters, random_seed):
class GmmAlgorithm(object):
"""Tensorflow Gaussian mixture model clustering class."""
+ CLUSTERS_WEIGHT = 'alphas'
CLUSTERS_VARIABLE = 'clusters'
CLUSTERS_COVS_VARIABLE = 'clusters_covs'
@@ -187,11 +188,13 @@ class GmmAlgorithm(object):
array_ops.expand_dims(array_ops.diag_part(cov), 0),
[self._num_classes, 1])
self._covs = variables.Variable(
- covs, name='clusters_covs', validate_shape=False)
+ covs, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
# Mixture weights, representing the probability that a randomly
# selected unobservable data (in EM terms) was generated by component k.
self._alpha = variables.Variable(
- array_ops.tile([1.0 / self._num_classes], [self._num_classes]))
+ array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
+ name=self.CLUSTERS_WEIGHT,
+ validate_shape=False)
def training_ops(self):
"""Returns the training operation."""
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py
index 1452c90072..c951a6981f 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py
@@ -109,6 +109,16 @@ class GMMTest(test.TestCase):
np.linalg.inv(covs[assignments[r]])), points[r, :] -
means[assignments[r]])))
return (points, assignments, scores)
+
+ def test_weights(self):
+ """Tests the shape of the weights."""
+ gmm = gmm_lib.GMM(self.num_centers,
+ initial_clusters=self.initial_means,
+ random_seed=4,
+ config=run_config.RunConfig(tf_random_seed=2))
+ gmm.fit(input_fn=self.input_fn(), steps=0)
+ weights = gmm.weights()
+ self.assertAllEqual(list(weights.shape), [self.num_centers])
def test_clusters(self):
"""Tests the shape of the clusters."""
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 89b9245172..e236f03018 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -480,6 +480,7 @@ py_test(
size = "medium",
srcs = ["python/learn/estimators/estimator_test.py"],
srcs_version = "PY2AND3",
+ tags = ["manual"],
deps = [
":learn",
"//tensorflow/contrib/framework:framework_py",
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 96802a570c..d1113678a9 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -191,6 +191,9 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
if not dnn_feature_columns:
dnn_logits = None
else:
+ if not dnn_hidden_units:
+ raise ValueError(
+ "dnn_hidden_units must be defined when dnn_feature_columns is specified.")
dnn_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas))
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
index cdab569c65..01e14c32e5 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py
@@ -241,6 +241,26 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
dnn_feature_columns=None,
dnn_hidden_units=[3, 3])
+ def testNoDnnHiddenUnits(self):
+ def _input_fn():
+ return {
+ 'age':
+ constant_op.constant([1]),
+ 'language':
+ sparse_tensor.SparseTensor(
+ values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
+ }, constant_op.constant([[1]])
+
+ language = feature_column.sparse_column_with_hash_bucket('language', 100)
+ age = feature_column.real_valued_column('age')
+
+ with self.assertRaisesRegexp(
+ ValueError,
+ 'dnn_hidden_units must be defined when dnn_feature_columns is specified'):
+ classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
+ dnn_feature_columns=[age, language])
+ classifier.fit(input_fn=_input_fn, steps=2)
+
def testEmbeddingMultiplier(self):
embedding_language = feature_column.embedding_column(
feature_column.sparse_column_with_hash_bucket('language', 10),
diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py
index e2af0fa7b6..234605ff76 100644
--- a/tensorflow/contrib/learn/python/learn/models.py
+++ b/tensorflow/contrib/learn/python/learn/models.py
@@ -274,10 +274,10 @@ def bidirectional_rnn(cell_fw,
output_bw = _reverse_seq(tmp, sequence_length)
# Concat each of the forward/backward outputs
outputs = [
- array_ops_.concat_v2([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
+ array_ops_.concat([fw, bw], 1) for fw, bw in zip(output_fw, output_bw)
]
- return outputs, array_ops_.concat_v2([state_fw, state_bw], 1)
+ return outputs, array_ops_.concat([state_fw, state_bw], 1)
# End of TensorFlow 0.7
diff --git a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py
index 5ac9bfd808..fa3b7323e3 100644
--- a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py
@@ -59,7 +59,7 @@ def embedding_lookup(params, ids, name='embedding_lookup'):
ids_flat = array_ops_.reshape(
ids, math_ops.reduce_prod(shape, keep_dims=True))
embeds_flat = nn.embedding_lookup(params, ids_flat, name)
- embed_shape = array_ops_.concat_v2([shape, [-1]], 0)
+ embed_shape = array_ops_.concat([shape, [-1]], 0)
embeds = array_ops_.reshape(embeds_flat, embed_shape)
embeds.set_shape(ids.get_shape().concatenate(params.get_shape()[1:]))
return embeds
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py
index 1e4fb58945..5ca8c8a18b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops.py
@@ -427,7 +427,6 @@ def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
[logits, labels, weights]) as scope:
labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])
- weights = array_ops.squeeze(weights)
losses = nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
logits=logits,
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 94b8dfca57..81a4aaba2b 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -243,6 +243,34 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
expected_value = 400.0 * label_smoothing / 3.0
self.assertAlmostEqual(loss.eval(), expected_value, 3)
+ def testLossWithDynamicallyShapedWeights1D(self):
+ logits = constant_op.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = constant_op.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = [2.3, 2.4, 2.5]
+ weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+ loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+ with self.test_session() as sess:
+ loss = sess.run(loss, {weights_placeholder: weights})
+ self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+ def testLossWithDynamicallyShapedWeights2D(self):
+ logits = constant_op.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = constant_op.constant([[0, 0, 1],
+ [1, 0, 0],
+ [0, 1, 0]])
+ weights = [[2.3], [2.4], [2.5]]
+ weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+ loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
+ with self.test_session() as sess:
+ loss = sess.run(loss, {weights_placeholder: weights})
+ self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
@@ -445,6 +473,30 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights=weights).eval()
+ def testLossWithDynamicallyShapedWeights1D(self):
+ logits = constant_op.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = constant_op.constant([2, 0, 1])
+ weights = [2.3, 2.4, 2.5]
+ weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
+ loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+ with self.test_session() as sess:
+ loss = sess.run(loss, {weights_placeholder: weights})
+ self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
+ def testLossWithDynamicallyShapedWeights2D(self):
+ logits = constant_op.constant([[10.0, 0.0, 0.0],
+ [0.0, 10.0, 0.0],
+ [0.0, 0.0, 10.0]])
+ labels = constant_op.constant([2, 0, 1])
+ weights = [[2.3], [2.4], [2.5]]
+ weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None, None])
+ loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights_placeholder)
+ with self.test_session() as sess:
+ loss = sess.run(loss, {weights_placeholder: weights})
+ self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
+
class SigmoidCrossEntropyLossTest(test.TestCase):
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index a6db4bdd36..c7f32baa2d 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -84,7 +84,7 @@ cuda_py_test(
tf_cuda_cc_test(
name = "nccl_manager_test",
- size = "small",
+ size = "medium",
srcs = if_cuda(
[
"kernels/nccl_manager.cc",
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index fd46230448..19b5788f2d 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -95,7 +95,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
feature_size = 2
frequency_skip = 1
- num_shifts = (input_size - feature_size) / frequency_skip + 1
+ num_shifts = (input_size - feature_size) // frequency_skip + 1
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([batch_size, input_size])
diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md
index 1c192076ce..bcc641e04a 100644
--- a/tensorflow/contrib/slim/README.md
+++ b/tensorflow/contrib/slim/README.md
@@ -880,7 +880,7 @@ names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
# Create the summary ops such that they also print out to std output:
summary_ops = []
-for metric_name, metric_value in metrics_to_values.iteritems():
+for metric_name, metric_value in names_to_values.iteritems():
op = tf.summary.scalar(metric_name, metric_value)
op = tf.Print(op, [metric_value], metric_name)
summary_ops.append(op)
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
new file mode 100644
index 0000000000..bd59c626f2
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -0,0 +1,76 @@
+# Description:
+# Contains ops to train linear models on top of TensorFlow.
+# APIs here are meant to evolve over time.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package(default_visibility = ["//visibility:public"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_library",
+ "tf_py_test",
+)
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_kernel_tests_linkstatic",
+)
+
+py_library(
+ name = "sparsemax_py",
+ srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ ],
+)
+
+cuda_py_tests(
+ name = "sparsemax_test",
+ size = "small",
+ srcs = ["python/kernel_tests/sparsemax_test.py"],
+ additional_deps = [
+ ":sparsemax_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_tests(
+ name = "sparsemax_loss_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/sparsemax_loss_test.py"],
+ additional_deps = [
+ ":sparsemax_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py
new file mode 100644
index 0000000000..0be4988dbf
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/__init__.py
@@ -0,0 +1,30 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Module that implements sparsemax and sparsemax loss, see [1].
+
+[1] https://arxiv.org/abs/1602.02068
+
+## Sparsemax
+
+@@sparsemax
+@@sparsemax_loss
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax
+from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \
+ import sparsemax_loss
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
new file mode 100644
index 0000000000..89dbcd96f8
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -0,0 +1,224 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SparsemaxLossOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.sparsemax import sparsemax, sparsemax_loss
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+test_obs = 10
+
+
+class SparsemaxLossTest(test.TestCase):
+
+ def _np_sparsemax(self, z):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # sort z
+ z_sorted = np.sort(z, axis=1)[:, ::-1]
+
+ # calculate k(z)
+ z_cumsum = np.cumsum(z_sorted, axis=1)
+ k = np.arange(1, z.shape[1] + 1)
+ z_check = 1 + k * z_sorted > z_cumsum
+ # use argmax to get the index by row as .nonzero() doesn't
+ # take an axis argument. np.argmax return the first index, but the last
+ # index is required here, use np.flip to get the last index and
+ # `z.shape[axis]` to compensate for np.flip afterwards.
+ k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)
+
+ # calculate tau(z)
+ tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
+ tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)
+
+ # calculate p
+ return np.maximum(0, z - tau_z)
+
+ def _np_sparsemax_loss(self, z, q):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # Calculate q^T * z
+ z_k = np.sum(q * z, axis=1)
+
+ # calculate sum over S(z)
+ p = self._np_sparsemax(z)
+ s = p > 0
+ # z_i^2 - tau(z)^2 = p_i (2 * z_i - p_i) for i \in S(z)
+ S_sum = np.sum(s * p * (2 * z - p), axis=1)
+
+ # because q is binary, sum([q_1^2, q_2^2, ...]) is just sum(q)
+ q_norm = np.sum(q, axis=1)
+
+ return -z_k + 0.5 * S_sum + 0.5 * q_norm
+
+ def _np_sparsemax_loss_grad(self, z, q):
+ # chain rule
+ grad = 1
+
+ return grad * (-q + self._np_sparsemax(z))
+
+ def _tf_sparsemax(self, z, dtype, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z.astype(dtype))
+ tf_sparsemax_out = tf_sparsemax_op.eval()
+
+ return tf_sparsemax_op, tf_sparsemax_out
+
+ def _tf_sparsemax_loss(self, z, q, dtype, use_gpu):
+ z = z.astype(dtype)
+ q = q.astype(dtype)
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z)
+ tf_loss_op = sparsemax_loss(z, tf_sparsemax_op, q)
+ tf_loss_out = tf_loss_op.eval()
+
+ return tf_loss_op, tf_loss_out
+
+ def _test_sparsemax_loss_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax-loss kernel against numpy"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+ np_loss = self._np_sparsemax_loss(z, q).astype(dtype)
+
+ self.assertAllCloseAccordingToType(np_loss, tf_loss_out,
+ half_atol=1e-2, half_rtol=5e-3)
+ self.assertShapeEqual(np_loss, tf_loss_op)
+
+ def _test_constant_add(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 3"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ c = random.uniform(low=-3, high=3, size=(test_obs, 1))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1
+
+ _, tf_loss_zpc = self._tf_sparsemax_loss(
+ z + c, q, dtype, use_gpu
+ )
+
+ _, tf_loss_z = self._tf_sparsemax_loss(
+ z, q, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(tf_loss_zpc, tf_loss_z,
+ float_atol=5e-6, float_rtol=5e-6,
+ half_atol=1e-2, half_rtol=1e-2)
+
+ def _test_sparsemax_loss_positive(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 4"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), random.randint(0, 10, size=test_obs)] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+
+ self.assertAllCloseAccordingToType(np.abs(tf_loss_out), tf_loss_out)
+ self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)
+
+ def _test_sparsemax_loss_zero(self, dtype, random, use_gpu):
+ """check sparsemax-loss proposition 5"""
+ # construct z and q, such that z_k >= 1 + max_{j!=k} z_k holds for
+ # delta_0 = 1.
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ z[:, 0] = np.max(z, axis=1) + 1.05
+
+ q = np.zeros((test_obs, 10))
+ q[:, 0] = 1
+
+ tf_loss_op, tf_loss_out = self._tf_sparsemax_loss(z, q, dtype, use_gpu)
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+
+ self.assertAllCloseAccordingToType(np.zeros(test_obs), tf_loss_out)
+ self.assertShapeEqual(np.zeros(test_obs), tf_loss_op)
+
+ self.assertAllCloseAccordingToType(q, tf_sparsemax_out)
+ self.assertShapeEqual(q, tf_sparsemax_op)
+
+ def _test_gradient_against_estimate(self, dtype, random, use_gpu):
+ """check sparsemax-loss Rop, aginst estimated-loss Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+ q = np.zeros((test_obs, 10)).astype(dtype)
+ q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1
+
+ logits = array_ops.placeholder(dtype, name='z')
+ sparsemax_op = sparsemax(logits)
+ loss_op = sparsemax_loss(logits, sparsemax_op, q)
+
+ with self.test_session(use_gpu=use_gpu):
+ err = gradient_checker.compute_gradient_error(
+ logits, z.shape,
+ loss_op, (test_obs, ),
+ x_init_value=z, delta=1e-9
+ )
+
+ self.assertLess(err, 1e-4)
+
+ def _test_gradient_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax-loss Rop, aginst numpy Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ q = np.zeros((test_obs, 10))
+ q[np.arange(0, test_obs), np.random.randint(0, 10, size=test_obs)] = 1
+
+ logits = constant_op.constant(z.astype(dtype), name='z')
+ sparsemax_op = sparsemax(logits)
+ loss_op = sparsemax_loss(logits, sparsemax_op, q.astype(dtype))
+ loss_grad_op = gradients_impl.gradients(loss_op, [logits])[0]
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_grad = loss_grad_op.eval()
+ np_grad = self._np_sparsemax_loss_grad(z, q).astype(dtype)
+
+ self.assertAllCloseAccordingToType(np_grad, tf_grad,
+ half_atol=1e-2, half_rtol=5e-3)
+ self.assertShapeEqual(np_grad, loss_grad_op)
+
+ def _test_dtype(self, dtype):
+ random = np.random.RandomState(1)
+
+ self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+
+ self._test_constant_add(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_zero(dtype, random, use_gpu=False)
+
+ # sparsemax is not a smooth function so gradient estimation is only
+ # possibol for float64.
+ if dtype == 'float64':
+ self._test_gradient_against_estimate(dtype, random, use_gpu=False)
+
+ self._test_gradient_against_numpy(dtype, random, use_gpu=False)
+
+ def testFloat(self):
+ self._test_dtype('float32')
+
+ def testDouble(self):
+ self._test_dtype('float64')
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
new file mode 100644
index 0000000000..eafac1b9ae
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -0,0 +1,252 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SparsemaxOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.sparsemax import sparsemax
+from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+test_obs = 10
+
+
+class SparsemaxTest(test.TestCase):
+
+ def _np_sparsemax(self, z):
+ z = z - np.mean(z, axis=1)[:, np.newaxis]
+
+ # sort z
+ z_sorted = np.sort(z, axis=1)[:, ::-1]
+
+ # calculate k(z)
+ z_cumsum = np.cumsum(z_sorted, axis=1)
+ k = np.arange(1, z.shape[1] + 1)
+ z_check = 1 + k * z_sorted > z_cumsum
+ # use argmax to get the index by row as .nonzero() doesn't
+ # take an axis argument. np.argmax return the first index, but the last
+ # index is required here, use np.flip to get the last index and
+ # `z.shape[axis]` to compensate for np.flip afterwards.
+ k_z = z.shape[1] - np.argmax(z_check[:, ::-1], axis=1)
+
+ # calculate tau(z)
+ tau_sum = z_cumsum[np.arange(0, z.shape[0]), k_z - 1]
+ tau_z = ((tau_sum - 1) / k_z).reshape(-1, 1)
+
+ # calculate p
+ return np.maximum(0, z - tau_z)
+
+ def _np_sparsemax_grad(self, z):
+ # chain rule
+ grad = np.ones_like(z)
+
+ # Construct S(z)
+ probability = self._np_sparsemax(z)
+ support = probability > 0
+
+ # Calculate \hat{v}, which will be a vector (scalar for each z)
+ v_hat = np.sum(grad * support, axis=1) / np.sum(support, axis=1)
+
+ # Calculates J(z) * v
+ return support * (grad - v_hat[:, np.newaxis])
+
+ def _tf_sparsemax(self, z, dtype, use_gpu):
+ with self.test_session(use_gpu=use_gpu):
+ tf_sparsemax_op = sparsemax(z.astype(dtype))
+ tf_sparsemax_out = tf_sparsemax_op.eval()
+
+ return tf_sparsemax_op, tf_sparsemax_out
+
+ def _test_sparsemax_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax kernel against numpy"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+ p_sparemax = self._np_sparsemax(z).astype(dtype)
+
+ self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out,
+ half_atol=5e-3)
+ self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+
+ def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
+ """check sparsemax proposition 1, part 1"""
+ z = np.zeros((1, 10))
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+ p_sparemax = np.ones_like(z, dtype=dtype) / z.size
+
+ self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
+ self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax proposition 1, part 2"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+
+ # assume |A(z)| = 1, as z is continues random
+ z_sort_arg = np.argsort(z, axis=1)[:, ::-1]
+ z_sort = np.sort(z, axis=-1)[:, ::-1]
+ gamma_z = z_sort[:, 0] - z_sort[:, 1]
+ epsilon = (0.99 * gamma_z * 1).reshape(-1, 1)
+
+ # construct the expected 1_A(z) array
+ p_expected = np.zeros((test_obs, 10), dtype=dtype)
+ p_expected[np.arange(0, test_obs), z_sort_arg[:, 0]] = 1
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
+ (1 / epsilon) * z, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out)
+ self.assertShapeEqual(p_expected, tf_sparsemax_op)
+
+ def _test_constant_add(self, dtype, random, use_gpu):
+ """check sparsemax proposition 2"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+ c = random.uniform(low=-3, high=3, size=(test_obs, 1)).astype(dtype)
+
+ _, tf_sparsemax_zpc = self._tf_sparsemax(
+ z + c, dtype, use_gpu
+ )
+
+ _, tf_sparsemax_z = self._tf_sparsemax(
+ z, dtype, use_gpu
+ )
+
+ self.assertAllCloseAccordingToType(tf_sparsemax_zpc, tf_sparsemax_z,
+ half_atol=5e-3)
+
+ def _test_permutation(self, dtype, random, use_gpu):
+ """check sparsemax proposition 3"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ _, p = self._tf_sparsemax(z, dtype, use_gpu)
+
+ for i in range(test_obs):
+ per = random.permutation(10)
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(
+ z[i, per].reshape(1, -1), dtype, use_gpu
+ )
+ p_expected = p[i, per].reshape(1, -1)
+
+ self.assertAllCloseAccordingToType(p_expected, tf_sparsemax_out,
+ half_atol=5e-3)
+ self.assertShapeEqual(p_expected, tf_sparsemax_op)
+
+ def _test_diffrence(self, dtype, random, use_gpu):
+ """check sparsemax proposition 4"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10))
+ _, p = self._tf_sparsemax(z, dtype, use_gpu)
+
+ etol = {'float16': 1e-2, 'float32': 1e-6, 'float64': 1e-9}[dtype]
+
+ for val in range(0, test_obs):
+ for i in range(0, 10):
+ for j in range(0, 10):
+ # check condition, the obesite pair will be checked anyway
+ if z[val, i] > z[val, j]:
+ continue
+
+ self.assertTrue(
+ 0 <= p[val, j] - p[val, i] <= z[val, j] - z[val, i] + etol,
+ "0 <= %.10f <= %.10f" % (
+ p[val, j] - p[val, i], z[val, j] - z[val, i] + etol
+ )
+ )
+
+ def _test_two_dimentional(self, dtype, random, use_gpu):
+ """check two dimentation sparsemax case"""
+ t = np.linspace(-2, 2, test_obs, dtype=dtype)
+ z = np.vstack([
+ t, np.zeros(test_obs, dtype=dtype)
+ ]).T
+
+ tf_sparsemax_op, tf_sparsemax_out = self._tf_sparsemax(z, dtype, use_gpu)
+
+ p0_expected = np.select([t < -1, t <= 1, t > 1], [0, (t + 1) / 2, 1])
+
+ self.assertAllCloseAccordingToType(p0_expected, tf_sparsemax_out[:, 0])
+ self.assertAllCloseAccordingToType(1 - p0_expected, tf_sparsemax_out[:, 1])
+ self.assertShapeEqual(z, tf_sparsemax_op)
+
+ def _test_gradient_against_estimate(self, dtype, random, use_gpu):
+ """check sparsemax Rop, aginst estimated Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+
+ logits = array_ops.placeholder(dtype, name='z')
+ sparsemax_op = sparsemax(logits)
+
+ with self.test_session(use_gpu=use_gpu):
+ err = gradient_checker.compute_gradient_error(
+ logits, z.shape,
+ sparsemax_op, z.shape,
+ x_init_value=z, delta=1e-9
+ )
+
+ self.assertLess(err, 1e-4)
+
+ def _test_gradient_against_numpy(self, dtype, random, use_gpu):
+ """check sparsemax Rop, aginst numpy Rop"""
+ z = random.uniform(low=-3, high=3, size=(test_obs, 10)).astype(dtype)
+
+ logits = constant_op.constant(z, name='z')
+ sparsemax_op = sparsemax(logits)
+ sparsemax_grad_op = gradients_impl.gradients(sparsemax_op, [logits])[0]
+
+ with self.test_session(use_gpu=use_gpu):
+ tf_grad = sparsemax_grad_op.eval()
+ np_grad = self._np_sparsemax_grad(z)
+
+ self.assertAllCloseAccordingToType(np_grad, tf_grad)
+ self.assertShapeEqual(np_grad, sparsemax_grad_op)
+
+ def _test_dtype(self, dtype):
+ random = np.random.RandomState(1)
+
+ self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+
+ self._test_constant_add(dtype, random, use_gpu=False)
+
+ self._test_permutation(dtype, random, use_gpu=False)
+
+ self._test_diffrence(dtype, random, use_gpu=False)
+
+ self._test_two_dimentional(dtype, random, use_gpu=False)
+
+ # sparsemax is not a smooth function so gradient estimation is only
+ # possibol for float64.
+ if dtype == 'float64':
+ self._test_gradient_against_estimate(dtype, random, use_gpu=False)
+
+ self._test_gradient_against_numpy(dtype, random, use_gpu=False)
+
+ def testFloat(self):
+ self._test_dtype('float32')
+
+ def testDouble(self):
+ self._test_dtype('float64')
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
new file mode 100644
index 0000000000..6e1cd75f22
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -0,0 +1,74 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparsemax op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.framework import ops, dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn
+
+
+def sparsemax(logits, name=None):
+ """Computes sparsemax activations [1].
+
+ For each batch `i` and class `j` we have
+ sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)
+
+ [1]: https://arxiv.org/abs/1602.02068
+
+ Args:
+ logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
+ `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`.
+ """
+
+ with ops.name_scope(name, "sparsemax", [logits]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ obs = array_ops.shape(logits)[0]
+ dims = array_ops.shape(logits)[1]
+
+ z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+
+ # sort z
+ z_sorted, _ = nn.top_k(z, k=dims)
+
+ # calculate k(z)
+ z_cumsum = math_ops.cumsum(z_sorted, axis=1)
+ k = math_ops.range(
+ 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype
+ )
+ z_check = 1 + k * z_sorted > z_cumsum
+ # because the z_check vector is always [1,1,...1,0,0,...0] finding the
+ # (index + 1) of the last `1` is the same as just summing the number of 1.
+ k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
+
+ # calculate tau(z)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ tau_sum = array_ops.gather_nd(z_cumsum, indices)
+ tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
+
+ # calculate p
+ return math_ops.maximum(
+ math_ops.cast(0, logits.dtype),
+ z - tau_z[:, array_ops.newaxis]
+ )
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
new file mode 100644
index 0000000000..1f5e8c37e3
--- /dev/null
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -0,0 +1,59 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Sparsemax Loss op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+
+
+def sparsemax_loss(logits, sparsemax, labels, name=None):
+ """Computes sparsemax loss function [1].
+
+ [1]: https://arxiv.org/abs/1602.02068
+
+ Args:
+ logits: A `Tensor`. Must be one of the following types: `half`, `float32`,
+ `float64`.
+ sparsemax: A `Tensor`. Must have the same type as `logits`.
+ labels: A `Tensor`. Must have the same type as `logits`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor`. Has the same type as `logits`.
+ """
+
+ with ops.name_scope(name, "sparsemax_loss",
+ [logits, sparsemax, labels]) as name:
+ logits = ops.convert_to_tensor(logits, name="logits")
+ sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
+ labels = ops.convert_to_tensor(labels, name="labels")
+
+ shifted_logits = logits - \
+ math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+
+ # sum over support
+ support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
+ sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+
+ # - z_k + ||q||^2
+ q_part = labels * (0.5 * labels - shifted_logits)
+
+ return math_ops.reduce_sum(sum_s + q_part, axis=1)