diff options
-rw-r--r-- tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py |  6 +++---
-rw-r--r-- tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py          | 12 +++---------
2 files changed, 6 insertions(+), 12 deletions(-)
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
index 98eac39683..16102aec6a 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
@@ -105,8 +105,8 @@ class LinearOperatorDiagtest(
 
   def test_broadcast_apply_and_solve(self):
     # These cannot be done in the automated (base test class) tests since they
-    # test shapes that tf.batch_matmul cannot handle.
-    # In particular, tf.batch_matmul does not broadcast.
+    # test shapes that tf.matmul cannot handle.
+    # In particular, tf.matmul does not broadcast.
     with self.test_session() as sess:
       x = tf.random_normal(shape=(2, 2, 3, 4))
@@ -122,7 +122,7 @@ class LinearOperatorDiagtest(
       self.assertAllEqual((2, 2, 3, 3), mat.get_shape())  # being pedantic.
 
       operator_apply = operator.apply(x)
-      mat_apply = tf.batch_matmul(mat, x)
+      mat_apply = tf.matmul(mat, x)
       self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape())
       self.assertAllClose(*sess.run([operator_apply, mat_apply]))
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index adbdb9b3d2..2a2700b492 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -74,12 +74,6 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
     """Make a rhs appropriate for calling operator.apply(rhs)."""
     raise NotImplementedError("_make_x is not defined.")
 
-  def _maybe_adjoint(self, x, adjoint):
-    if adjoint:
-      return tf.matrix_transpose(x)
-    else:
-      return x
-
   def test_to_dense(self):
     with self.test_session() as sess:
       for shape in self._shapes_to_test:
@@ -134,7 +128,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
           continue
         x = self._make_x(operator)
         op_apply = operator.apply(x, adjoint=adjoint)
-        mat_apply = tf.batch_matmul(self._maybe_adjoint(mat, adjoint), x)
+        mat_apply = tf.matmul(mat, x, adjoint_a=adjoint)
         self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
         op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply])
         self.assertAllClose(op_apply_v, mat_apply_v)
@@ -147,7 +141,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
         operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
             shape, dtype, use_placeholder=True)
         x = self._make_x(operator)
         op_apply_v, mat_apply_v = sess.run(
-            [operator.apply(x), tf.batch_matmul(mat, x)],
+            [operator.apply(x), tf.matmul(mat, x)],
             feed_dict=feed_dict)
         self.assertAllClose(op_apply_v, mat_apply_v)
@@ -162,7 +156,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
           continue
         rhs = self._make_rhs(operator)
         op_solve = operator.solve(rhs, adjoint=adjoint)
-        mat_solve = tf.matrix_solve(self._maybe_adjoint(mat, adjoint), rhs)
+        mat_solve = tf.matrix_solve(mat, rhs, adjoint=adjoint)
         self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
         op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
         self.assertAllClose(op_solve_v, mat_solve_v)