aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/linalg
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-29 08:28:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-29 08:49:28 -0800
commitdf71c58cd4a5ec9a64f8a8ad5af02ecec42b38c5 (patch)
tree5074ab424dc58a1e2323c2f23f50723689db823d /tensorflow/contrib/linalg
parent99700a09c632ad14d99a54d8f1db64928e32d8c6 (diff)
Remove tf.batch_matmul Python interface.
Change: 140481905
Diffstat (limited to 'tensorflow/contrib/linalg')
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py6
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py12
2 files changed, 6 insertions, 12 deletions
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
index 98eac39683..16102aec6a 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
@@ -105,8 +105,8 @@ class LinearOperatorDiagtest(
def test_broadcast_apply_and_solve(self):
# These cannot be done in the automated (base test class) tests since they
- # test shapes that tf.batch_matmul cannot handle.
- # In particular, tf.batch_matmul does not broadcast.
+ # test shapes that tf.matmul cannot handle.
+ # In particular, tf.matmul does not broadcast.
with self.test_session() as sess:
x = tf.random_normal(shape=(2, 2, 3, 4))
@@ -122,7 +122,7 @@ class LinearOperatorDiagtest(
self.assertAllEqual((2, 2, 3, 3), mat.get_shape()) # being pedantic.
operator_apply = operator.apply(x)
- mat_apply = tf.batch_matmul(mat, x)
+ mat_apply = tf.matmul(mat, x)
self.assertAllEqual(operator_apply.get_shape(), mat_apply.get_shape())
self.assertAllClose(*sess.run([operator_apply, mat_apply]))
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index adbdb9b3d2..2a2700b492 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -74,12 +74,6 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
"""Make a rhs appropriate for calling operator.apply(rhs)."""
raise NotImplementedError("_make_x is not defined.")
- def _maybe_adjoint(self, x, adjoint):
- if adjoint:
- return tf.matrix_transpose(x)
- else:
- return x
-
def test_to_dense(self):
with self.test_session() as sess:
for shape in self._shapes_to_test:
@@ -134,7 +128,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
continue
x = self._make_x(operator)
op_apply = operator.apply(x, adjoint=adjoint)
- mat_apply = tf.batch_matmul(self._maybe_adjoint(mat, adjoint), x)
+ mat_apply = tf.matmul(mat, x, adjoint_a=adjoint)
self.assertAllEqual(op_apply.get_shape(), mat_apply.get_shape())
op_apply_v, mat_apply_v = sess.run([op_apply, mat_apply])
self.assertAllClose(op_apply_v, mat_apply_v)
@@ -147,7 +141,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
shape, dtype, use_placeholder=True)
x = self._make_x(operator)
op_apply_v, mat_apply_v = sess.run(
- [operator.apply(x), tf.batch_matmul(mat, x)],
+ [operator.apply(x), tf.matmul(mat, x)],
feed_dict=feed_dict)
self.assertAllClose(op_apply_v, mat_apply_v)
@@ -162,7 +156,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
continue
rhs = self._make_rhs(operator)
op_solve = operator.solve(rhs, adjoint=adjoint)
- mat_solve = tf.matrix_solve(self._maybe_adjoint(mat, adjoint), rhs)
+ mat_solve = tf.matrix_solve(mat, rhs, adjoint=adjoint)
self.assertAllEqual(op_solve.get_shape(), mat_solve.get_shape())
op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
self.assertAllClose(op_solve_v, mat_solve_v)