aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/solvers
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-11-17 17:12:07 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-17 17:26:52 -0800
commit80913f02c87abd033f85185e005ffb4bb0bd09b1 (patch)
tree858eeceaa5f8c86ab1952887b3edea9314a22a09 /tensorflow/contrib/solvers
parentedbddbd13edd7dc03389e09e7e6c56d48acf8d29 (diff)
Deprecate tf.batch_matmul and replace with equivalent calls to tf.matmul that now supports adjoint and batch matmul.
CL created by: replace_string \ batch_matmul\\\( \ matmul\( plus some manual edits, mostly s/adj_x/adjoint_a/ s/adj_y/adjoint_b/. Change: 139527377
Diffstat (limited to 'tensorflow/contrib/solvers')
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py4
-rw-r--r--tensorflow/contrib/solvers/python/kernel_tests/util_test.py45
-rw-r--r--tensorflow/contrib/solvers/python/ops/lanczos.py2
-rw-r--r--tensorflow/contrib/solvers/python/ops/util.py4
4 files changed, 31 insertions, 24 deletions
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
index 7df736a293..5fea07cd83 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/lanczos_test.py
@@ -56,9 +56,9 @@ def _get_lanczos_tests(dtype_, use_static_shape_, shape_, orthogonalize_,
# The computed factorization should satisfy the equations
# A * V = U * B
# A' * U[:, :-1] = V * B[:-1, :]'
- av = tf.batch_matmul(a, lbd.v)
+ av = tf.matmul(a, lbd.v)
ub = lanczos.bidiag_matmul(lbd.u, lbd.alpha, lbd.beta, adjoint_b=False)
- atu = tf.batch_matmul(a, lbd.u[:, :-1], adj_x=True)
+ atu = tf.matmul(a, lbd.u[:, :-1], adjoint_a=True)
vbt = lanczos.bidiag_matmul(lbd.v, lbd.alpha, lbd.beta, adjoint_b=True)
if use_static_shape_:
diff --git a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
index 2a9ba0f893..c1d85546e8 100644
--- a/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
+++ b/tensorflow/contrib/solvers/python/kernel_tests/util_test.py
@@ -26,26 +26,33 @@ from tensorflow.contrib.solvers.python.ops import util
class UtilTest(tf.test.TestCase):
def _testCreateOperator(self, use_static_shape_):
- a_np = np.array([[1., 2.], [3., 4.], [5., 6.]])
- x = np.array([[2.], [-3.]])
- y = np.array([[2], [-3.], [5.]])
- with self.test_session() as sess:
- if use_static_shape_:
- a = tf.constant(a_np, dtype=tf.float32)
- else:
- a = tf.placeholder(tf.float32)
- op = util.create_operator(a)
- ax = op.apply(x)
- aty = op.apply_adjoint(y)
- op_shape = tf.convert_to_tensor(op.shape)
- if use_static_shape_:
- op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty])
- else:
- op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty],
- feed_dict={a: a_np})
+ for dtype in np.float32, np.float64:
+ a_np = np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=dtype)
+ x_np = np.array([[2.], [-3.]], dtype=dtype)
+ y_np = np.array([[2], [-3.], [5.]], dtype=dtype)
+ with self.test_session() as sess:
+ if use_static_shape_:
+ a = tf.constant(a_np, dtype=dtype)
+ x = tf.constant(x_np, dtype=dtype)
+ y = tf.constant(y_np, dtype=dtype)
+ else:
+ a = tf.placeholder(dtype)
+ x = tf.placeholder(dtype)
+ y = tf.placeholder(dtype)
+ op = util.create_operator(a)
+ ax = op.apply(x)
+ aty = op.apply_adjoint(y)
+ op_shape = tf.convert_to_tensor(op.shape)
+ if use_static_shape_:
+ op_shape_val, ax_val, aty_val = sess.run([op_shape, ax, aty])
+ else:
+ op_shape_val, ax_val, aty_val = sess.run(
+ [op_shape, ax, aty], feed_dict={a: a_np,
+ x: x_np,
+ y: y_np})
self.assertAllEqual(op_shape_val, [3, 2])
- self.assertAllClose(ax_val, [[-4], [-6], [-8]])
- self.assertAllClose(aty_val, [[18], [22]])
+ self.assertAllClose(ax_val, np.dot(a_np, x_np))
+ self.assertAllClose(aty_val, np.dot(a_np.T, y_np))
def testCreateOperator(self):
self._testCreateOperator(True)
diff --git a/tensorflow/contrib/solvers/python/ops/lanczos.py b/tensorflow/contrib/solvers/python/ops/lanczos.py
index 4e666c24dc..0a6c17eea2 100644
--- a/tensorflow/contrib/solvers/python/ops/lanczos.py
+++ b/tensorflow/contrib/solvers/python/ops/lanczos.py
@@ -119,7 +119,7 @@ def lanczos_bidiag(operator,
"""Makes v orthogonal to the j'th vector in basis."""
v_shape = v.get_shape()
basis_vec = read_colvec(basis, j)
- v -= tf.batch_matmul(basis_vec, v, adj_x=True) * basis_vec
+ v -= tf.matmul(basis_vec, v, adjoint_a=True) * basis_vec
v.set_shape(v_shape)
return j + 1, basis, v
diff --git a/tensorflow/contrib/solvers/python/ops/util.py b/tensorflow/contrib/solvers/python/ops/util.py
index e9bcb0d61d..ca1fb2918b 100644
--- a/tensorflow/contrib/solvers/python/ops/util.py
+++ b/tensorflow/contrib/solvers/python/ops/util.py
@@ -40,8 +40,8 @@ def create_operator(matrix):
dtype=matrix.dtype,
# TODO(rmlarsen): We are only using batch_matmul here because matmul
# only has transpose and not adjoint.
- apply=lambda v: tf.batch_matmul(matrix, v, adj_x=False),
- apply_adjoint=lambda v: tf.batch_matmul(matrix, v, adj_x=True))
+ apply=lambda v: tf.matmul(matrix, v, adjoint_a=False),
+ apply_adjoint=lambda v: tf.matmul(matrix, v, adjoint_a=True))
# TODO(rmlarsen): Measure if we should just call matmul.