aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author: Sergio Guadarrama <sguada@google.com>, 2017-06-15 13:07:27 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2017-06-15 13:11:18 -0700
commit: 9bba02c6ed00d59d0dc5bc9bb6f8a32662ef8103 (patch)
tree: 6b3d8402d038b287ff92aaa1811b87192296293f
parent: fa75f26351f42e4fd3fc89b553d7919a6f147e41 (diff)
Add poincare_normalization
PiperOrigin-RevId: 159142705
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers.py | 38
-rw-r--r-- tensorflow/contrib/layers/python/layers/layers_test.py | 64
2 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index d3b1094963..393a2488bc 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2145,6 +2145,44 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None):
return math_ops.div(inputs, array_ops.tile(lengths, multiples))
+def poincare_normalize(x, axis=1, epsilon=1e-5, name=None):
+  """Project into the Poincare ball with norm <= 1.0 - epsilon.
+
+  https://en.wikipedia.org/wiki/Poincare_ball_model
+
+  Used in
+  Poincare Embeddings for Learning Hierarchical Representations
+  Maximilian Nickel, Douwe Kiela
+  https://arxiv.org/pdf/1705.08039.pdf
+
+  For a 1-D tensor with `axis = 0`, computes
+
+                (x * (1 - epsilon)) / ||x||     if ||x|| > 1 - epsilon
+      output =
+                 x                              otherwise
+
+  For `x` with more dimensions, independently normalizes each 1-D slice along
+  dimension `axis`. When `axis` lists several dimensions, the squared entries
+  are summed over all of them, so the norm is taken jointly over those
+  dimensions.
+
+  Slices whose norm is zero are returned unchanged (all zeros) and receive a
+  finite gradient.
+
+  Args:
+    x: A `Tensor`.
+    axis: Axis along which to normalize.  A scalar or a vector of
+      integers.
+    epsilon: A small deviation from the edge of the unit sphere for numerical
+      stability.
+    name: A name for this operation (optional).
+
+  Returns:
+    A `Tensor` with the same shape as `x`.
+  """
+  # Lower bound applied to the squared norm before taking rsqrt. Any slice
+  # with a squared norm below this bound already has a scale factor that is
+  # clipped to 1 below (for every epsilon <= 1 - 1e-6), so the forward result
+  # is unchanged by the bound.
+  min_square_sum = 1e-12
+  with ops.name_scope(name, 'poincare_normalize', [x]) as name:
+    x = ops.convert_to_tensor(x, name='x')
+    square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True)
+    # Without the bound an all-zero slice gives rsqrt(0) = inf. The minimum
+    # below hides that in the forward pass, but the backward pass multiplies
+    # a zero upstream gradient by an infinite derivative and yields NaN.
+    x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, min_square_sum))
+    # Scale factor (1 - epsilon) / ||x||, capped at 1 so that points already
+    # inside the ball are left untouched.
+    x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.)
+    return math_ops.multiply(x, x_inv_norm, name=name)
+
+
def legacy_fully_connected(x,
num_output_units,
activation_fn=None,
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index b49c33e996..67f45473e8 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -3231,6 +3232,69 @@ class UnitNormTests(test.TestCase):
self.assertAllClose(expected, actual, 1e-4, 1e-4)
+class PoincareNormalizeTest(test.TestCase):
+  """Checks `poincare_normalize` values and gradients against NumPy."""
+
+  def _Norm(self, x, dim, keepdims=False):
+    """Returns the Euclidean norm of `x` over `dim` (an int or list of ints).
+
+    Computed as sqrt(sum(x ** 2)) so that any number of axes, in any order,
+    is supported.
+    """
+    axis = tuple(dim) if isinstance(dim, list) else dim
+    return np.sqrt(np.sum(np.square(x), axis=axis, keepdims=keepdims))
+
+  def _PoincareNormalize(self, x, dim, epsilon=1e-5):
+    """NumPy reference implementation of `poincare_normalize`.
+
+    Args:
+      x: Input array.
+      dim: Axis (int) or axes (list of ints) to normalize over.
+      epsilon: Margin kept from the unit sphere.
+
+    Returns:
+      An array shaped like `x` in which every slice whose norm exceeds
+      `1 - epsilon` is rescaled to norm `1 - epsilon`; other slices are
+      returned unchanged.
+    """
+    # keepdims=True keeps the reduced axes as size-1 axes, so the norm
+    # broadcasts against `x` without re-inserting axes by hand.
+    norm = self._Norm(x, dim, keepdims=True)
+    norm_x = ((1. - epsilon) * x) / norm
+    return np.where(norm > 1.0 - epsilon, norm_x, x)
+
+  def _CheckPoincareNormalize(self, x_np, dim, epsilon, tol):
+    """Compares the op with the NumPy reference and checks the norm bound."""
+    y_np = self._PoincareNormalize(x_np, dim, epsilon)
+    with self.test_session():
+      x_tf = constant_op.constant(x_np, name='x')
+      y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
+      y_tf_eval = y_tf.eval()
+    max_norm = 1. - epsilon + tol
+    # Both the reference and the op output must stay inside the ball.
+    self.assertLessEqual(self._Norm(y_np, dim).max(), max_norm)
+    self.assertLessEqual(self._Norm(y_tf_eval, dim).max(), max_norm)
+    self.assertAllClose(y_np, y_tf_eval, rtol=1e-6, atol=1e-6)
+
+  def testPoincareNormalize(self):
+    x_shape = [20, 7, 3]
+    epsilon = 1e-5
+    tol = 1e-6
+    np.random.seed(1)
+    x_np = np.random.random_sample(x_shape).astype(np.float32)
+    # Normalize along each single axis in turn.
+    for dim in range(len(x_shape)):
+      self._CheckPoincareNormalize(x_np, dim, epsilon, tol)
+
+  def testPoincareNormalizeDimArray(self):
+    x_shape = [20, 7, 3]
+    epsilon = 1e-5
+    tol = 1e-6
+    np.random.seed(1)
+    x_np = np.random.random_sample(x_shape).astype(np.float32)
+    # Normalize jointly over the last two axes.
+    self._CheckPoincareNormalize(x_np, [1, 2], epsilon, tol)
+
+  def testPoincareNormalizeGradient(self):
+    x_shape = [20, 7, 3]
+    np.random.seed(1)
+    # float64 input keeps the numeric gradient estimate accurate.
+    x_np = np.random.random_sample(x_shape).astype(np.float64)
+    for dim in range(len(x_shape)):
+      with self.test_session():
+        x_tf = constant_op.constant(x_np, name='x')
+        y_tf = _layers.poincare_normalize(x_tf, dim)
+        err = gradient_checker.compute_gradient_error(x_tf, x_shape,
+                                                      y_tf, x_shape)
+      print('PoinCareNormalize gradient err = %g ' % err)
+      self.assertLess(err, 1e-4)
+
+
# TODO(b/28426988): Add separate tests for non-legacy versions.
class LegacyFullyConnectedTest(test.TestCase):