aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/norm_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/norm_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/norm_op_test.py16
1 files changed, 7 insertions, 9 deletions
diff --git a/tensorflow/python/kernel_tests/norm_op_test.py b/tensorflow/python/kernel_tests/norm_op_test.py
index d85512fae6..3f71b326a2 100644
--- a/tensorflow/python/kernel_tests/norm_op_test.py
+++ b/tensorflow/python/kernel_tests/norm_op_test.py
@@ -37,17 +37,17 @@ class NormOpTest(test_lib.TestCase):
def testBadOrder(self):
matrix = [[0., 1.], [2., 3.]]
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
- linalg_ops.norm(matrix, ord="fro")
+ linalg_ops.norm(matrix, ord=ord_)
- for ord_ in "foo", -7, -1.1, 0:
+ for ord_ in "fro", -7, -1.1, 0:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported vector norm"):
linalg_ops.norm(matrix, ord=ord_, axis=-1)
- for ord_ in 1.1, 2:
+ for ord_ in "foo", -7, -1.1, 1.1:
with self.assertRaisesRegexp(ValueError,
"'ord' must be a supported matrix norm"):
linalg_ops.norm(matrix, ord=ord_, axis=[-2, -1])
@@ -69,14 +69,14 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if use_static_shape_:
tf_matrix = constant_op.constant(matrix)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm)
else:
tf_matrix = array_ops.placeholder(dtype_)
tf_norm = linalg_ops.norm(
- tf_matrix, ord=ord_, axis=axis_, keep_dims=keep_dims_)
+ tf_matrix, ord=ord_, axis=axis_, keepdims=keep_dims_)
tf_norm_val = sess.run(tf_norm, feed_dict={tf_matrix: matrix})
- self.assertAllClose(np_norm, tf_norm_val)
+ self.assertAllClose(np_norm, tf_norm_val, rtol=1e-5, atol=1e-5)
def Test(self):
is_matrix_norm = (isinstance(axis_, tuple) or
@@ -85,8 +85,6 @@ def _GetNormOpTest(dtype_, shape_, ord_, axis_, keep_dims_, use_static_shape_):
if ((not is_matrix_norm and ord_ == "fro") or
(is_matrix_norm and is_fancy_p_norm)):
self.skipTest("Not supported by neither numpy.linalg.norm nor tf.norm")
- if is_matrix_norm and ord_ == 2:
- self.skipTest("Not supported by tf.norm")
if ord_ == 'euclidean' or (axis_ is None and len(shape) > 2):
self.skipTest("Not supported by numpy.linalg.norm")
matrix = np.random.randn(*shape_).astype(dtype_)