diff options
author | 2017-02-08 09:25:09 -0800 | |
---|---|---|
committer | 2017-02-08 09:50:05 -0800 | |
commit | 639b4e71f532761a4840b1cdbaea55ad0917c75b (patch) | |
tree | 5116415b1d9ff82f054dd4feeadd81cb833d6435 /tensorflow/contrib/sparsemax/python/ops | |
parent | 15ff7b702788c0cf75bb8d5ce090f06490098cf7 (diff) |
Merge changes from github.
Change: 146918929
Diffstat (limited to 'tensorflow/contrib/sparsemax/python/ops')
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax.py | 74 | ||||
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py | 59 |
2 files changed, 133 insertions, 0 deletions
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py new file mode 100644 index 0000000000..6e1cd75f22 --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -0,0 +1,74 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sparsemax op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader +from tensorflow.python.framework import ops, dtypes +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn + + +def sparsemax(logits, name=None): + """Computes sparsemax activations [1]. + + For each batch `i` and class `j` we have + sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0) + + [1]: https://arxiv.org/abs/1602.02068 + + Args: + logits: A `Tensor`. Must be one of the following types: `half`, `float32`, + `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `logits`. + """ + + with ops.name_scope(name, "sparsemax", [logits]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + obs = array_ops.shape(logits)[0] + dims = array_ops.shape(logits)[1] + + z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + + # sort z + z_sorted, _ = nn.top_k(z, k=dims) + + # calculate k(z) + z_cumsum = math_ops.cumsum(z_sorted, axis=1) + k = math_ops.range( + 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype + ) + z_check = 1 + k * z_sorted > z_cumsum + # because the z_check vector is always [1,1,...1,0,0,...0] finding the + # (index + 1) of the last `1` is the same as just summing the number of 1. + k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1) + + # calculate tau(z) + indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1) + tau_sum = array_ops.gather_nd(z_cumsum, indices) + tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype) + + # calculate p + return math_ops.maximum( + math_ops.cast(0, logits.dtype), + z - tau_z[:, array_ops.newaxis] + ) diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py new file mode 100644 index 0000000000..1f5e8c37e3 --- /dev/null +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -0,0 +1,59 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Sparsemax Loss op.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +def sparsemax_loss(logits, sparsemax, labels, name=None): + """Computes sparsemax loss function [1]. + + [1]: https://arxiv.org/abs/1602.02068 + + Args: + logits: A `Tensor`. Must be one of the following types: `half`, `float32`, + `float64`. + sparsemax: A `Tensor`. Must have the same type as `logits`. + labels: A `Tensor`. Must have the same type as `logits`. + name: A name for the operation (optional). + + Returns: + A `Tensor`. Has the same type as `logits`. + """ + + with ops.name_scope(name, "sparsemax_loss", + [logits, sparsemax, labels]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax") + labels = ops.convert_to_tensor(labels, name="labels") + + shifted_logits = logits - \ + math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis] + + # sum over support + support = math_ops.cast(sparsemax > 0, sparsemax.dtype) + sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax) + + # - z_k + ||q||^2 + q_part = labels * (0.5 * labels - shifted_logits) + + return math_ops.reduce_sum(sum_s + q_part, axis=1) |