diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-02-08 12:16:52 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-08 12:27:27 -0800 |
commit | 83d350275c6835b818168e704fb1329d410b63a9 (patch) | |
tree | 30d27aec6eb095cfad0a46f7e0089fc522dfbdd3 /tensorflow/contrib/sparsemax | |
parent | 085102c2e2947d76056b6363da96c55ecd838e6c (diff) |
Seal the sparsemax module
Change: 146941104
Diffstat (limited to 'tensorflow/contrib/sparsemax')
-rw-r--r-- | tensorflow/contrib/sparsemax/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/sparsemax/__init__.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax.py | 17 | ||||
-rw-r--r-- | tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py | 3 |
4 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD index bd59c626f2..7441f1429f 100644 --- a/tensorflow/contrib/sparsemax/BUILD +++ b/tensorflow/contrib/sparsemax/BUILD @@ -27,6 +27,7 @@ py_library( deps = [ "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py index 0be4988dbf..19d213fb3e 100644 --- a/tensorflow/contrib/sparsemax/__init__.py +++ b/tensorflow/contrib/sparsemax/__init__.py @@ -28,3 +28,8 @@ from __future__ import print_function from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \ import sparsemax_loss +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['sparsemax', 'sparsemax_loss'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py index 6e1cd75f22..07ac24add4 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py @@ -13,16 +13,20 @@ # limitations under the License. # ============================================================================== """Sparsemax op.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader -from tensorflow.python.framework import ops, dtypes -from tensorflow.python.ops import math_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.platform import resource_loader + +__all__ = ["sparsemax"] def sparsemax(logits, name=None): @@ -55,8 +59,7 @@ def sparsemax(logits, name=None): # calculate k(z) z_cumsum = math_ops.cumsum(z_sorted, axis=1) k = math_ops.range( - 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype - ) + 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype) z_check = 1 + k * z_sorted > z_cumsum # because the z_check vector is always [1,1,...1,0,0,...0] finding the # (index + 1) of the last `1` is the same as just summing the number of 1. @@ -69,6 +72,4 @@ def sparsemax(logits, name=None): # calculate p return math_ops.maximum( - math_ops.cast(0, logits.dtype), - z - tau_z[:, array_ops.newaxis] - ) + math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis]) diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py index 1f5e8c37e3..ba18f89e16 100644 --- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py +++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Sparsemax Loss op.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -23,6 +24,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +__all__ = ["sparsemax_loss"] + def sparsemax_loss(logits, sparsemax, labels, name=None): """Computes sparsemax loss function [1]. |