diff options
Diffstat (limited to 'tensorflow/contrib/solvers/python/kernel_tests/util_test.py')
-rw-r--r-- | tensorflow/contrib/solvers/python/kernel_tests/util_test.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py index 1566984b27..5d7534657b 100644 --- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py +++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py @@ -63,6 +63,43 @@ class UtilTest(test.TestCase): def testCreateOperatorUnknownShape(self): self._testCreateOperator(False) + def _testIdentityOperator(self, use_static_shape_): + for dtype in np.float32, np.float64: + a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype) + x_np = np.array([[2.], [-3.]], dtype=dtype) + y_np = np.array([[2], [-3.], [5.]], dtype=dtype) + with self.test_session() as sess: + if use_static_shape_: + a = constant_op.constant(a_np, dtype=dtype) + x = constant_op.constant(x_np, dtype=dtype) + y = constant_op.constant(y_np, dtype=dtype) + else: + a = array_ops.placeholder(dtype) + x = array_ops.placeholder(dtype) + y = array_ops.placeholder(dtype) + id_op = util.identity_operator(a) + ax = id_op.apply(x) + aty = id_op.apply_adjoint(y) + op_shape = ops.convert_to_tensor(id_op.shape) + if use_static_shape_: + op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty]) + else: + op_shape_val, ax_val, aty_val = sess.run( + [op_shape, ax, aty], feed_dict={ + a: a_np, + x: x_np, + y: y_np + }) + self.assertAllEqual(op_shape_val, [3, 2]) + self.assertAllClose(ax_val, x_np) + self.assertAllClose(aty_val, y_np) + + def testIdentityOperator(self): + self._testIdentityOperator(True) + + def testIdentityOperatorUnknownShape(self): + self._testIdentityOperator(False) + def testL2Norm(self): with self.test_session(): x_np = np.array([[2], [-3.], [5.]]) |