aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-22 08:37:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-22 09:48:36 -0700
commit3b7153de39732170cc5abd01b0e051bf98e066eb (patch)
treeb229442282d99fb2076455ea7d0eeb8d08367875
parenta5bcbf1284ea7c191c08334a1c331559653c4118 (diff)
Fix L2Normalize when passing a list of dims. Fixes #3932.
Change: 130947456
-rw-r--r--tensorflow/python/ops/nn.py5
-rw-r--r--tensorflow/python/ops/nn_test.py21
2 files changed, 22 insertions, 4 deletions
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index fa3fba3375..fd4eedcfcc 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -547,7 +547,8 @@ def l2_normalize(x, dim, epsilon=1e-12, name=None):
Args:
x: A `Tensor`.
- dim: Dimension along which to normalize.
+ dim: Dimension along which to normalize. A scalar or a vector of
+ integers.
epsilon: A lower bound value for the norm. Will use `sqrt(epsilon)` as the
divisor if `norm < sqrt(epsilon)`.
name: A name for this operation (optional).
@@ -557,7 +558,7 @@ def l2_normalize(x, dim, epsilon=1e-12, name=None):
"""
with ops.name_scope(name, "l2_normalize", [x]) as name:
x = ops.convert_to_tensor(x, name="x")
- square_sum = math_ops.reduce_sum(math_ops.square(x), [dim], keep_dims=True)
+ square_sum = math_ops.reduce_sum(math_ops.square(x), dim, keep_dims=True)
x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon))
return math_ops.mul(x, x_inv_norm, name=name)
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 4ab346bcbd..d146af478f 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -180,8 +180,14 @@ class L2LossTest(tf.test.TestCase):
class L2NormalizeTest(tf.test.TestCase):
def _l2Normalize(self, x, dim):
- norm = np.apply_along_axis(np.linalg.norm, dim, x)
- return x / np.expand_dims(norm, dim)
+ if isinstance(dim, list):
+ norm = np.linalg.norm(x, axis=tuple(dim))
+ for d in dim:
+ norm = np.expand_dims(norm, d)
+ return x / norm
+ else:
+ norm = np.apply_along_axis(np.linalg.norm, dim, x)
+ return x / np.expand_dims(norm, dim)
def testL2Normalize(self):
x_shape = [20, 7, 3]
@@ -194,6 +200,17 @@ class L2NormalizeTest(tf.test.TestCase):
y_tf = tf.nn.l2_normalize(x_tf, dim)
self.assertAllClose(y_np, y_tf.eval())
+ def testL2NormalizeDimArray(self):
+ x_shape = [20, 7, 3]
+ np.random.seed(1)
+ x_np = np.random.random_sample(x_shape).astype(np.float32)
+ dim = [1, 2]
+ y_np = self._l2Normalize(x_np, dim)
+ with self.test_session():
+ x_tf = tf.constant(x_np, name="x")
+ y_tf = tf.nn.l2_normalize(x_tf, dim)
+ self.assertAllClose(y_np, y_tf.eval())
+
def testL2NormalizeGradient(self):
x_shape = [20, 7, 3]
np.random.seed(1)