aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/svd_op_test.py
diff options
context:
space:
mode:
authorGravatar Justine Tunney <jart@google.com>2016-12-14 16:30:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-14 16:43:13 -0800
commit5866e065bc95c1d7de8a27413b368016941889a6 (patch)
tree55b7db600e38b3a799ab39053cd99e61204f840b /tensorflow/python/kernel_tests/svd_op_test.py
parent38a664cd961762e64899187a31a1b86cbe5a992e (diff)
Remove hourglass imports from kernel_tests
Change: 142080137
Diffstat (limited to 'tensorflow/python/kernel_tests/svd_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/svd_op_test.py54
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py
index 2934e90ea2..3f6b6958fc 100644
--- a/tensorflow/python/kernel_tests/svd_op_test.py
+++ b/tensorflow/python/kernel_tests/svd_op_test.py
@@ -13,26 +13,32 @@
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.math_ops.matrix_inverse."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
-import tensorflow as tf
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
-class SvdOpTest(tf.test.TestCase):
+class SvdOpTest(test.TestCase):
def testWrongDimensions(self):
# The input to svd should be a tensor of at least rank 2.
- scalar = tf.constant(1.)
+ scalar = constant_op.constant(1.)
with self.assertRaisesRegexp(ValueError,
"Shape must be at least rank 2 but is rank 0"):
- tf.svd(scalar)
- vector = tf.constant([1., 2.])
+ linalg_ops.svd(scalar)
+ vector = constant_op.constant([1., 2.])
with self.assertRaisesRegexp(ValueError,
"Shape must be at least rank 2 but is rank 1"):
- tf.svd(vector)
+ linalg_ops.svd(vector)
def _GetSvdOpTest(dtype_, shape_, use_static_shape_):
@@ -77,22 +83,22 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_):
batch_shape = a.shape[:-2]
m = a.shape[-2]
n = a.shape[-1]
- diag_s = tf.cast(tf.matrix_diag(s), dtype=dtype_)
+ diag_s = math_ops.cast(array_ops.matrix_diag(s), dtype=dtype_)
if full_matrices:
if m > n:
- zeros = tf.zeros(batch_shape + (m - n, n), dtype=dtype_)
- diag_s = tf.concat_v2([diag_s, zeros], a.ndim - 2)
+ zeros = array_ops.zeros(batch_shape + (m - n, n), dtype=dtype_)
+ diag_s = array_ops.concat_v2([diag_s, zeros], a.ndim - 2)
elif n > m:
- zeros = tf.zeros(batch_shape + (m, n - m), dtype=dtype_)
- diag_s = tf.concat_v2([diag_s, zeros], a.ndim - 1)
- a_recon = tf.matmul(u, diag_s)
- a_recon = tf.matmul(a_recon, v, adjoint_b=True)
+ zeros = array_ops.zeros(batch_shape + (m, n - m), dtype=dtype_)
+ diag_s = array_ops.concat_v2([diag_s, zeros], a.ndim - 1)
+ a_recon = math_ops.matmul(u, diag_s)
+ a_recon = math_ops.matmul(a_recon, v, adjoint_b=True)
self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol)
def CheckUnitary(self, x):
# Tests that x[...,:,:]^H * x[...,:,:] is close to the identity.
- xx = tf.matmul(x, x, adjoint_a=True)
- identity = tf.matrix_band_part(tf.ones_like(xx), 0, 0)
+ xx = math_ops.matmul(x, x, adjoint_a=True)
+ identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
if is_single:
tol = 1e-5
else:
@@ -112,23 +118,23 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_):
for full_matrices in False, True:
with self.test_session() as sess:
if use_static_shape_:
- x_tf = tf.constant(x_np)
+ x_tf = constant_op.constant(x_np)
else:
- x_tf = tf.placeholder(dtype_)
+ x_tf = array_ops.placeholder(dtype_)
if compute_uv:
- s_tf, u_tf, v_tf = tf.svd(x_tf,
- compute_uv=compute_uv,
- full_matrices=full_matrices)
+ s_tf, u_tf, v_tf = linalg_ops.svd(x_tf,
+ compute_uv=compute_uv,
+ full_matrices=full_matrices)
if use_static_shape_:
s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf])
else:
s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf],
feed_dict={x_tf: x_np})
else:
- s_tf = tf.svd(x_tf,
- compute_uv=compute_uv,
- full_matrices=full_matrices)
+ s_tf = linalg_ops.svd(x_tf,
+ compute_uv=compute_uv,
+ full_matrices=full_matrices)
if use_static_shape_:
s_tf_val = sess.run(s_tf)
else:
@@ -171,4 +177,4 @@ if __name__ == "__main__":
use_static_shape)
setattr(SvdOpTest, "testSvd_" + name,
_GetSvdOpTest(dtype, shape, use_static_shape))
- tf.test.main()
+ test.main()