aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/solvers/python/kernel_tests/util_test.py')
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/util_test.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 1566984b27..5d7534657b 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -63,6 +63,43 @@ class UtilTest(test.TestCase):
def testCreateOperatorUnknownShape(self):
self._testCreateOperator(False)
+ def _testIdentityOperator(self, use_static_shape_):
+ for dtype in np.float32, np.float64:
+ a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
+ x_np = np.array([[2.], [-3.]], dtype=dtype)
+ y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
+ with self.test_session() as sess:
+ if use_static_shape_:
+ a = constant_op.constant(a_np, dtype=dtype)
+ x = constant_op.constant(x_np, dtype=dtype)
+ y = constant_op.constant(y_np, dtype=dtype)
+ else:
+ a = array_ops.placeholder(dtype)
+ x = array_ops.placeholder(dtype)
+ y = array_ops.placeholder(dtype)
+ id_op = util.identity_operator(a)
+ ax = id_op.apply(x)
+ aty = id_op.apply_adjoint(y)
+ op_shape = ops.convert_to_tensor(id_op.shape)
+ if use_static_shape_:
+ op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty])
+ else:
+ op_shape_val, ax_val, aty_val = sess.run(
+ [op_shape, ax, aty], feed_dict={
+ a: a_np,
+ x: x_np,
+ y: y_np
+ })
+ self.assertAllEqual(op_shape_val, [3, 2])
+ self.assertAllClose(ax_val, x_np)
+ self.assertAllClose(aty_val, y_np)
+
+ def testIdentityOperator(self):
+ self._testIdentityOperator(True)
+
+ def testIdentityOperatorUnknownShape(self):
+ self._testIdentityOperator(False)
+
def testL2Norm(self):
with self.test_session():
x_np = np.array([[2], [-3.], [5.]])