aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/sparsemax
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-02-08 12:16:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-02-08 12:27:27 -0800
commit83d350275c6835b818168e704fb1329d410b63a9 (patch)
tree30d27aec6eb095cfad0a46f7e0089fc522dfbdd3 /tensorflow/contrib/sparsemax
parent085102c2e2947d76056b6363da96c55ecd838e6c (diff)
Seal the sparsemax module
Change: 146941104
Diffstat (limited to 'tensorflow/contrib/sparsemax')
-rw-r--r--tensorflow/contrib/sparsemax/BUILD1
-rw-r--r--tensorflow/contrib/sparsemax/__init__.py5
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax.py17
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py3
4 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/contrib/sparsemax/BUILD b/tensorflow/contrib/sparsemax/BUILD
index bd59c626f2..7441f1429f 100644
--- a/tensorflow/contrib/sparsemax/BUILD
+++ b/tensorflow/contrib/sparsemax/BUILD
@@ -27,6 +27,7 @@ py_library(
deps = [
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
diff --git a/tensorflow/contrib/sparsemax/__init__.py b/tensorflow/contrib/sparsemax/__init__.py
index 0be4988dbf..19d213fb3e 100644
--- a/tensorflow/contrib/sparsemax/__init__.py
+++ b/tensorflow/contrib/sparsemax/__init__.py
@@ -28,3 +28,8 @@ from __future__ import print_function
from tensorflow.contrib.sparsemax.python.ops.sparsemax import sparsemax
from tensorflow.contrib.sparsemax.python.ops.sparsemax_loss \
import sparsemax_loss
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = ['sparsemax', 'sparsemax_loss']
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index 6e1cd75f22..07ac24add4 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -13,16 +13,20 @@
# limitations under the License.
# ==============================================================================
"""Sparsemax op."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
-from tensorflow.python.framework import ops, dtypes
-from tensorflow.python.ops import math_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
+from tensorflow.python.platform import resource_loader
+
+__all__ = ["sparsemax"]
def sparsemax(logits, name=None):
@@ -55,8 +59,7 @@ def sparsemax(logits, name=None):
# calculate k(z)
z_cumsum = math_ops.cumsum(z_sorted, axis=1)
k = math_ops.range(
- 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype
- )
+ 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype)
z_check = 1 + k * z_sorted > z_cumsum
# because the z_check vector is always [1,1,...1,0,0,...0] finding the
# (index + 1) of the last `1` is the same as just summing the number of 1.
@@ -69,6 +72,4 @@ def sparsemax(logits, name=None):
# calculate p
return math_ops.maximum(
- math_ops.cast(0, logits.dtype),
- z - tau_z[:, array_ops.newaxis]
- )
+ math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 1f5e8c37e3..ba18f89e16 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
"""Sparsemax Loss op."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,6 +24,8 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+__all__ = ["sparsemax_loss"]
+
def sparsemax_loss(logits, sparsemax, labels, name=None):
"""Computes sparsemax loss function [1].