diff options
author | 2016-12-14 16:30:24 -0800 | |
---|---|---|
committer | 2016-12-14 16:43:13 -0800 | |
commit | 5866e065bc95c1d7de8a27413b368016941889a6 (patch) | |
tree | 55b7db600e38b3a799ab39053cd99e61204f840b /tensorflow/python/kernel_tests/svd_op_test.py | |
parent | 38a664cd961762e64899187a31a1b86cbe5a992e (diff) |
Remove hourglass imports from kernel_tests
Change: 142080137
Diffstat (limited to 'tensorflow/python/kernel_tests/svd_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/svd_op_test.py | 54 |
1 files changed, 30 insertions, 24 deletions
diff --git a/tensorflow/python/kernel_tests/svd_op_test.py b/tensorflow/python/kernel_tests/svd_op_test.py index 2934e90ea2..3f6b6958fc 100644 --- a/tensorflow/python/kernel_tests/svd_op_test.py +++ b/tensorflow/python/kernel_tests/svd_op_test.py @@ -13,26 +13,32 @@ # limitations under the License. # ============================================================================== """Tests for tensorflow.ops.math_ops.matrix_inverse.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -import tensorflow as tf + +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test -class SvdOpTest(tf.test.TestCase): +class SvdOpTest(test.TestCase): def testWrongDimensions(self): # The input to svd should be a tensor of at least rank 2. - scalar = tf.constant(1.) + scalar = constant_op.constant(1.) with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 2 but is rank 0"): - tf.svd(scalar) - vector = tf.constant([1., 2.]) + linalg_ops.svd(scalar) + vector = constant_op.constant([1., 2.]) with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 2 but is rank 1"): - tf.svd(vector) + linalg_ops.svd(vector) def _GetSvdOpTest(dtype_, shape_, use_static_shape_): @@ -77,22 +83,22 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_): batch_shape = a.shape[:-2] m = a.shape[-2] n = a.shape[-1] - diag_s = tf.cast(tf.matrix_diag(s), dtype=dtype_) + diag_s = math_ops.cast(array_ops.matrix_diag(s), dtype=dtype_) if full_matrices: if m > n: - zeros = tf.zeros(batch_shape + (m - n, n), dtype=dtype_) - diag_s = tf.concat_v2([diag_s, zeros], a.ndim - 2) + zeros = array_ops.zeros(batch_shape + (m - n, n), dtype=dtype_) + diag_s = array_ops.concat_v2([diag_s, zeros], a.ndim - 2) elif n > m: - zeros = tf.zeros(batch_shape + (m, n - m), dtype=dtype_) - diag_s = tf.concat_v2([diag_s, zeros], a.ndim - 1) - a_recon = tf.matmul(u, diag_s) - a_recon = tf.matmul(a_recon, v, adjoint_b=True) + zeros = array_ops.zeros(batch_shape + (m, n - m), dtype=dtype_) + diag_s = array_ops.concat_v2([diag_s, zeros], a.ndim - 1) + a_recon = math_ops.matmul(u, diag_s) + a_recon = math_ops.matmul(a_recon, v, adjoint_b=True) self.assertAllClose(a_recon.eval(), a, rtol=tol, atol=tol) def CheckUnitary(self, x): # Tests that x[...,:,:]^H * x[...,:,:] is close to the identity. - xx = tf.matmul(x, x, adjoint_a=True) - identity = tf.matrix_band_part(tf.ones_like(xx), 0, 0) + xx = math_ops.matmul(x, x, adjoint_a=True) + identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0) if is_single: tol = 1e-5 else: @@ -112,23 +118,23 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_): for full_matrices in False, True: with self.test_session() as sess: if use_static_shape_: - x_tf = tf.constant(x_np) + x_tf = constant_op.constant(x_np) else: - x_tf = tf.placeholder(dtype_) + x_tf = array_ops.placeholder(dtype_) if compute_uv: - s_tf, u_tf, v_tf = tf.svd(x_tf, - compute_uv=compute_uv, - full_matrices=full_matrices) + s_tf, u_tf, v_tf = linalg_ops.svd(x_tf, + compute_uv=compute_uv, + full_matrices=full_matrices) if use_static_shape_: s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf]) else: s_tf_val, u_tf_val, v_tf_val = sess.run([s_tf, u_tf, v_tf], feed_dict={x_tf: x_np}) else: - s_tf = tf.svd(x_tf, - compute_uv=compute_uv, - full_matrices=full_matrices) + s_tf = linalg_ops.svd(x_tf, + compute_uv=compute_uv, + full_matrices=full_matrices) if use_static_shape_: s_tf_val = sess.run(s_tf) else: @@ -171,4 +177,4 @@ if __name__ == "__main__": use_static_shape) setattr(SvdOpTest, "testSvd_" + name, _GetSvdOpTest(dtype, shape, use_static_shape)) - tf.test.main() + test.main() |